diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..bbfc09765 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,41 @@ +# Set the default behavior +* text=auto + +# Go files +*.mod text eol=lf +*.sum text eol=lf +*.go text eol=lf + +# Serialization +*.yml eol=lf +*.yaml eol=lf +*.toml eol=lf +*.json eol=lf + +# Scripts +*.sh eol=lf + +# DB files +*.sql eol=lf + +# Html +*.html eol=lf + +# Text and markdown files +*.txt text eol=lf +*.md text eol=lf + +# Environment files/examples +*.env text eol=lf + +# Docker files +.dockerignore text eol=lf +Dockerfile* text eol=lf + +# Makefile +Makefile text eol=lf + +# Git files +.gitignore text eol=lf +.gitattributes text eol=lf +.gitkeep text eol=lf \ No newline at end of file diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d157515bb..561f0fbb8 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @netlify/opensource @netlify/backend +* @supabase/auth diff --git a/.github/workflows/conventional-commits-lint.js b/.github/workflows/conventional-commits-lint.js new file mode 100644 index 000000000..976a92b1b --- /dev/null +++ b/.github/workflows/conventional-commits-lint.js @@ -0,0 +1,109 @@ +"use strict"; + +const fs = require("fs"); + +const TITLE_PATTERN = + /^(?[^:!(]+)(?\([^)]+\))?(?[!])?:.+$/; +const RELEASE_AS_DIRECTIVE = /^\s*Release-As:/im; +const BREAKING_CHANGE_DIRECTIVE = /^\s*BREAKING[ \t]+CHANGE:/im; + +const ALLOWED_CONVENTIONAL_COMMIT_PREFIXES = [ + "revert", + "feat", + "fix", + "ci", + "docs", + "chore", +]; + +const object = process.argv[2]; +const payload = JSON.parse(fs.readFileSync(process.argv[3], "utf-8")); + +let validate = []; + +if (object === "pr") { + validate.push({ + title: payload.pull_request.title, + content: payload.pull_request.body, + }); +} else if (object === "push") { + validate.push( + ...payload.commits + .map((commit) => ({ + title: commit.message.split("\n")[0], + content: commit.message, + })) + .filter(({ title }) => !title.startsWith("Merge branch ") && !title.startsWith("Revert ")), + ); +} else { + console.error( + `Unknown object for first argument "${object}", use 'pr' or 'push'.`, + ); + process.exit(0); +} + +let failed = false; + +validate.forEach((payload) => { + if (payload.title) { + const match = payload.title.match(TITLE_PATTERN); + if (!match) { + return + } + + const { groups } = match + + if (groups) { + if (groups.breaking) { + console.error( + `PRs are not allowed to declare breaking changes at this stage of the project. Please remove the ! in your PR title or commit message and adjust the functionality to be backward compatible.`, + ); + failed = true; + } + + if ( + !ALLOWED_CONVENTIONAL_COMMIT_PREFIXES.find( + (prefix) => prefix === groups.prefix, + ) + ) { + console.error( + `PR (or a commit in it) is using a disallowed conventional commit prefix ("${groups.prefix}"). Only ${ALLOWED_CONVENTIONAL_COMMIT_PREFIXES.join(", ")} are allowed. Make sure the prefix is lowercase!`, + ); + failed = true; + } + + if (groups.package && groups.prefix !== "chore") { + console.warn( + "Avoid using package specifications in PR titles or commits except for the `chore` prefix.", + ); + } + } else { + console.error( + "PR or commit title must match conventional commit structure.", + ); + failed = true; + } + } + + if (payload.content) { + if (payload.content.match(RELEASE_AS_DIRECTIVE)) { + console.error( + "PR descriptions or commit messages must not contain Release-As conventional commit directives.", + ); + failed = true; + } + + if (payload.content.match(BREAKING_CHANGE_DIRECTIVE)) { + console.error( + "PR descriptions or commit messages must not contain a BREAKING CHANGE conventional commit directive. Please adjust the functionality to be backward compatible.", + ); + failed = true; + } + } +}); + +if (failed) { + process.exit(1); +} + +process.exit(0); diff --git a/.github/workflows/conventional-commits.yml b/.github/workflows/conventional-commits.yml new file mode 100644 index 000000000..b2f3499c6 --- /dev/null +++ b/.github/workflows/conventional-commits.yml @@ -0,0 +1,48 @@ +name: Check pull requests + +on: + push: + branches-ignore: # Run the checks on all branches but the protected ones + - master + - release/* + + pull_request: + branches: + - master + - release/* + types: + - opened + - edited + - reopened + - ready_for_review + +permissions: + contents: read + +jobs: + check-conventional-commits: + runs-on: ubuntu-latest + if: github.actor != 'dependabot[bot]' # skip for dependabot PRs + env: + EVENT: ${{ toJSON(github.event) }} + steps: + - uses: actions/checkout@v4 + with: + sparse-checkout: | + .github + + - if: ${{ github.event_name == 'pull_request' }} + run: | + set -ex + + TMP_FILE=$(mktemp) + echo "${EVENT}" > "$TMP_FILE" + node .github/workflows/conventional-commits-lint.js pr "${TMP_FILE}" + + - if: ${{ github.event_name == 'push' }} + run: | + set -ex + + TMP_FILE=$(mktemp) + echo "${EVENT}" > "$TMP_FILE" + node .github/workflows/conventional-commits-lint.js push "${TMP_FILE}" diff --git a/.github/workflows/dogfooding.yml b/.github/workflows/dogfooding.yml new file mode 100644 index 000000000..d040ee03d --- /dev/null +++ b/.github/workflows/dogfooding.yml @@ -0,0 +1,60 @@ +name: Dogfooding Check + +on: + pull_request_review: + types: [submitted, edited] + + pull_request: + types: + - opened + branches: + - '*' + +permissions: + contents: read + +jobs: + check_dogfooding: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + if: github.event.pull_request.base.ref == 'master' && github.event.pull_request.head.ref == 'release-please--branches--master' + with: + ref: master # used to identify the latest RC version via git describe --tags --match rc* + fetch-depth: 0 + + - if: github.event.pull_request.base.ref == 'master' && github.event.pull_request.head.ref == 'release-please--branches--master' + run: | + set -ex + + # finds the latest RC version on master + RELEASE_VERSION=$(node -e "const a = '$(git describe --tags --match rc*)'.replace(/^rc/, 'v').split('-'); console.log(a[0] + '-' + a[1]);") + + PROD_VERSION=$(curl 'https://auth.supabase.io/auth/v1/health' | jq -r .version) + STAGING_VERSION=$(curl 'https://alt.supabase.green/auth/v1/health' | jq -r .version) + + echo "Expecting RC version $RELEASE_VERSION to be up on prod and staging." + + if [ "$PROD_VERSION" != "$STAGING_VERSION" ] + then + echo "Versions on prod and staging don't match!" + + exit 1 + fi + + if [ "$PROD_VERSION" != "$RELEASE_VERSION" ] + then + echo "Version on prod $PROD_VERSION is not the latest release candidate. Please release this RC first to proof the release before merging this PR." + exit 1 + fi + + echo "Release away!" + exit 0 + + - if: github.event.pull_request.base.ref != 'master' || github.event.pull_request.head.ref != 'release-please--branches--master' + run: | + set -ex + + echo "This PR is not subject to dogfooding checks." + exit 0 + diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 000000000..61350593d --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,100 @@ +name: Publish to Image Registry + +on: + workflow_call: + inputs: + version: + required: true + type: string + +permissions: + contents: read + +jobs: + publish: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + id-token: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - id: meta + uses: docker/metadata-action@v4 + with: + images: | + supabase/gotrue + public.ecr.aws/supabase/gotrue + ghcr.io/supabase/gotrue + ghcr.io/supabase/auth + 436098097459.dkr.ecr.us-east-1.amazonaws.com/gotrue + 646182064048.dkr.ecr.us-east-1.amazonaws.com/gotrue + supabase/auth + public.ecr.aws/supabase/auth + 436098097459.dkr.ecr.us-east-1.amazonaws.com/auth + 646182064048.dkr.ecr.us-east-1.amazonaws.com/auth + flavor: | + latest=false + tags: | + type=raw,value=v${{ inputs.version }},enable=true + + - uses: docker/setup-qemu-action@v2 + with: + platforms: amd64,arm64 + + - run: | + set -ex + + echo "Adding explicit release version to Dockerfile..." + + sed -i 's/RELEASE_VERSION=unspecified/RELEASE_VERSION=${{ inputs.version }}/' Dockerfile + + - uses: docker/setup-buildx-action@v2 + + - name: Login to DockerHub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: configure aws credentials - prod + uses: aws-actions/configure-aws-credentials@v1 + with: + role-to-assume: ${{ secrets.PROD_AWS_ROLE }} + aws-region: us-east-1 + - name: Login to ECR + uses: docker/login-action@v2 + with: + registry: public.ecr.aws + - name: Login to ECR account - prod + uses: docker/login-action@v2 + with: + registry: 646182064048.dkr.ecr.us-east-1.amazonaws.com + + - name: configure aws credentials - staging + uses: aws-actions/configure-aws-credentials@v1 + with: + role-to-assume: ${{ secrets.DEV_AWS_ROLE }} + aws-region: us-east-1 + - name: Login to ECR account - staging + uses: docker/login-action@v2 + with: + registry: 436098097459.dkr.ecr.us-east-1.amazonaws.com + + - name: Login to GHCR + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - uses: docker/build-push-action@v3 + with: + context: . # IMPORTANT: Dockerfile is modified above to include the release version. Don't remove this line: https://github.com/docker/build-push-action?tab=readme-ov-file#git-context + push: true + platforms: linux/amd64,linux/arm64 + tags: ${{ steps.meta.outputs.tags }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 58fd4a676..7fb0cc0ee 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,110 +1,143 @@ -name: Build and Release on Push to Master +name: Release on: push: branches: - master + - release/* -jobs: - release: - strategy: - matrix: - node: ["16"] +permissions: + contents: read +jobs: + release_please: + runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write outputs: - status: ${{ steps.pre-release.outputs.release != steps.post-release.outputs.release }} - - runs-on: ubuntu-18.04 + MAIN_RELEASE_VERSION: ${{ steps.versions.outputs.MAIN_RELEASE_VERSION }} + RELEASE_VERSION: ${{ steps.versions.outputs.RELEASE_VERSION }} + RELEASE_CANDIDATE: ${{ steps.versions.outputs.RELEASE_CANDIDATE }} + RELEASE_NAME: ${{ steps.versions.outputs.RELEASE_NAME }} steps: - - uses: actions/checkout@v2 - - - id: pre-release - uses: pozetroninc/github-action-get-latest-release@master - with: - owner: supabase - repo: gotrue - excludes: prerelease, draft - - - name: Set up Node - uses: actions/setup-node@v2 - with: - node-version: ${{ matrix.node }} - - - name: Release on GitHub - id: github-release - run: npx semantic-release -p \ - @semantic-release/commit-analyzer \ - @semantic-release/github \ - @semantic-release/release-notes-generator - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - id: post-release - uses: pozetroninc/github-action-get-latest-release@master - with: - owner: supabase - repo: gotrue - excludes: prerelease, draft - - deploy: - needs: release - runs-on: ubuntu-18.04 - if: success() && needs.release.outputs.status == 'true' - steps: - - uses: actions/checkout@v2 - - - uses: actions/setup-go@v2 - with: - go-version: "^1.14.4" # The Go version to download (if necessary) and use. - - - run: make deps - - - run: make all - - - id: releases - uses: pozetroninc/github-action-get-latest-release@master - with: - owner: supabase - repo: gotrue - excludes: prerelease, draft - - - run: tar -czvf gotrue-${{ steps.releases.outputs.release }}-x86.tar.gz gotrue migrations/ - - run: mv gotrue-arm64 gotrue - - run: tar -czvf gotrue-${{ steps.releases.outputs.release }}-arm64.tar.gz gotrue migrations/ - - - uses: AButler/upload-release-assets@v2.0 + - uses: googleapis/release-please-action@v4 + id: release with: - files: "gotrue-${{ steps.releases.outputs.release }}*.tar.gz" - release-tag: ${{ steps.releases.outputs.release }} - repo-token: ${{ secrets.GITHUB_TOKEN }} + release-type: go + target-branch: ${{ github.ref_name }} - - name: Upload image to Docker Hub - uses: docker/build-push-action@v1 + - uses: actions/checkout@v4 + if: ${{ steps.release.outputs.release_created == 'true' || steps.release.outputs.prs_created == 'true' }} with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - repository: ${{ secrets.DOCKER_REPO }} - tags: latest,${{ steps.releases.outputs.release }} - - - name: Login to ECR account - staging - uses: docker/login-action@v1 - with: - registry: 436098097459.dkr.ecr.us-east-1.amazonaws.com - username: ${{ secrets.DEV_ACCESS_KEY_ID }} - password: ${{ secrets.DEV_SECRET_ACCESS_KEY }} - - - name: Login to ECR account - prod - uses: docker/login-action@v1 - with: - registry: 646182064048.dkr.ecr.us-east-1.amazonaws.com - username: ${{ secrets.PROD_ACCESS_KEY_ID }} - password: ${{ secrets.PROD_SECRET_ACCESS_KEY }} - - - name: Upload image to ECR - uses: docker/build-push-action@v2 + fetch-depth: 0 + + - if: ${{ steps.release.outputs.release_created == 'true' || steps.release.outputs.prs_created == 'true' }} + id: versions + run: | + set -ex + + RELEASE_CANDIDATE=true + NOT_RELEASE_CANDIDATE='${{ steps.release.outputs.release_created }}' + if [ "$NOT_RELEASE_CANDIDATE" == "true" ] + then + RELEASE_CANDIDATE=false + fi + + MAIN_RELEASE_VERSION=x + RELEASE_VERSION=y + + if [ "$RELEASE_CANDIDATE" == "true" ] + then + # Release please doesn't tell you the candidate version when it + # creates the PR, so we have to take it from the title. + MAIN_RELEASE_VERSION=$(node -e "console.log('${{ steps.release.outputs.pr && fromJSON(steps.release.outputs.pr).title }}'.split(' ').reverse().find(x => x.match(/[0-9]+[.][0-9]+[.][0-9]+/)))") + + # Use git describe tags to identify the number of commits the branch + # is ahead of the most recent non-release-candidate tag, which is + # part of the rc. value. + RELEASE_VERSION=$MAIN_RELEASE_VERSION-rc.$(node -e "console.log('$(git describe --tags --exclude rc*)'.split('-')[1])") + + # release-please only ignores releases that have a form like [A-Z0-9], so prefixing with rc + RELEASE_NAME="rc$RELEASE_VERSION" + else + MAIN_RELEASE_VERSION=${{ steps.release.outputs.major }}.${{ steps.release.outputs.minor }}.${{ steps.release.outputs.patch }} + RELEASE_VERSION="$MAIN_RELEASE_VERSION" + RELEASE_NAME="v$RELEASE_VERSION" + fi + + echo "MAIN_RELEASE_VERSION=${MAIN_RELEASE_VERSION}" >> "${GITHUB_ENV}" + echo "RELEASE_VERSION=${RELEASE_VERSION}" >> "${GITHUB_ENV}" + echo "RELEASE_CANDIDATE=${RELEASE_CANDIDATE}" >> "${GITHUB_ENV}" + echo "RELEASE_NAME=${RELEASE_NAME}" >> "${GITHUB_ENV}" + + echo "MAIN_RELEASE_VERSION=${MAIN_RELEASE_VERSION}" >> "${GITHUB_OUTPUT}" + echo "RELEASE_VERSION=${RELEASE_VERSION}" >> "${GITHUB_OUTPUT}" + echo "RELEASE_CANDIDATE=${RELEASE_CANDIDATE}" >> "${GITHUB_OUTPUT}" + echo "RELEASE_NAME=${RELEASE_NAME}" >> "${GITHUB_OUTPUT}" + + - uses: actions/setup-go@v5 + if: ${{ steps.release.outputs.release_created == 'true' || steps.release.outputs.prs_created == 'true' }} with: - context: . - push: true - tags: | - 436098097459.dkr.ecr.us-east-1.amazonaws.com/gotrue:${{ steps.releases.outputs.release }} - 646182064048.dkr.ecr.us-east-1.amazonaws.com/gotrue:${{ steps.releases.outputs.release }} + go-version: "1.23.7" # The Go version to download (if necessary) and use. + + - name: Build release artifacts + if: ${{ steps.release.outputs.release_created == 'true' || steps.release.outputs.prs_created == 'true' }} + run: | + set -ex + + RELEASE_VERSION=$RELEASE_VERSION make deps + RELEASE_VERSION=$RELEASE_VERSION make all + ln -s auth gotrue + tar -czvf auth-v$RELEASE_VERSION-x86.tar.gz auth gotrue migrations/ + mv auth-arm64 auth + tar -czvf auth-v$RELEASE_VERSION-arm64.tar.gz auth gotrue migrations/ + + - name: Upload release artifacts + if: ${{ steps.release.outputs.release_created == 'true' || steps.release.outputs.prs_created == 'true' }} + run: | + set -ex + + if [ "$RELEASE_CANDIDATE" == "true" ] + then + PR_NUMBER='${{ steps.release.outputs.pr && fromJSON(steps.release.outputs.pr).number }}' + + GH_TOKEN='${{ github.token }}' gh release \ + create $RELEASE_NAME \ + --title "v$RELEASE_VERSION" \ + --prerelease \ + -n "This is a release candidate. See release-please PR #$PR_NUMBER for context." + + GH_TOKEN='${{ github.token }}' gh pr comment "$PR_NUMBER" \ + -b "Release candidate [v$RELEASE_VERSION](https://github.com/supabase/gotrue/releases/tag/$RELEASE_NAME) published." + else + if [ "$GITHUB_REF" == "refs/heads/main" ] || [ "$GITHUB_REF" == "refs/heads/master" ] + then + IS_PATCH_ZERO=$(node -e "console.log('$RELEASE_VERSION'.endsWith('.0'))") + + if [ "$IS_PATCH_ZERO" == "true" ] + then + # Only create release branch if patch version is 0, as this + # means that the release can be patched in the future. + + GH_TOKEN='${{ github.token }}' gh api \ + --method POST \ + -H "Accept: application/vnd.github+json" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + /repos/supabase/gotrue/git/refs \ + -f "ref=refs/heads/release/${RELEASE_VERSION}" \ + -f "sha=$GITHUB_SHA" + fi + fi + fi + + GH_TOKEN='${{ github.token }}' gh release upload $RELEASE_NAME ./auth-v$RELEASE_VERSION-x86.tar.gz ./auth-v$RELEASE_VERSION-arm64.tar.gz + + publish: + needs: + - release_please + if: ${{ success() && needs.release_please.outputs.RELEASE_VERSION }} + uses: ./.github/workflows/publish.yml + secrets: inherit + with: + version: ${{ needs.release_please.outputs.RELEASE_VERSION }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dd0477171..e2b60e469 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,18 +3,22 @@ name: Test on: pull_request: push: - branches: [master] - tags: ['*'] + branches: + - master + tags: ["*"] + +permissions: + contents: read jobs: test: strategy: matrix: - go-version: [1.16.x] - runs-on: ubuntu-18.04 + go-version: [1.23.7] + runs-on: ubuntu-20.04 services: postgres: - image: postgres:13 + image: postgres:15 env: POSTGRES_USER: postgres POSTGRES_PASSWORD: root @@ -28,14 +32,53 @@ jobs: --health-retries 5 steps: - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - name: Checkout code uses: actions/checkout@v2 + - name: Check gofmt + run: | + set -x + + if [ ! -z $(gofmt -l .) ] + then + echo 'Make sure to run "gofmt -s -w ." before commit!' && exit 1 + fi + - name: Check go vet + run: | + set -x + go vet ./... + - name: Run static check + run: | + set -x + go install honnef.co/go/tools/cmd/staticcheck@latest + go install github.com/nishanths/exhaustive/cmd/exhaustive@latest + make static + - name: Check gosec + run: | + set -x + go install github.com/securego/gosec/v2/cmd/gosec@latest + make sec - name: Init Database run: psql -f hack/init_postgres.sql postgresql://postgres:root@localhost:5432/postgres - name: Run migrations run: make migrate_dev - name: Lint and test run: make test + - name: Cleanup coverage + run: | + set -x + + # since Go 1.20 these source files need to be deleted from the + # coverage profile as they contain legacy or untestable code (like + # `main` package) + + sed -i '/^github.com\/supabase\/auth\/client/d' coverage.out + sed -i '/^github.com\/supabase\/auth\/cmd/d' coverage.out + sed -i '/^github.com\/supabase\/auth\/docs/d' coverage.out + sed -i '/^github.com\/supabase\/auth\/main/d' coverage.out + + - uses: shogo82148/actions-goveralls@v1 + with: + path-to-profile: coverage.out diff --git a/.gitignore b/.gitignore index a249cc999..acab1be45 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,12 @@ vendor/ gotrue gotrue-arm64 +gotrue.exe +auth +auth-arm64 +auth.exe + +coverage.out .DS_Store .vscode diff --git a/.releaserc b/.releaserc new file mode 100644 index 000000000..32f0e4516 --- /dev/null +++ b/.releaserc @@ -0,0 +1,10 @@ +{ + "branches": [ + "master" + ], + "plugins": [ + "@semantic-release/commit-analyzer", + "@semantic-release/release-notes-generator", + "@semantic-release/github" + ] +} diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..c00dfab88 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,618 @@ +# Changelog + +## [2.170.0](https://github.com/supabase/auth/compare/v2.169.0...v2.170.0) (2025-03-06) + + +### Features + +* improvements to config reloader, 100% coverage ([#1933](https://github.com/supabase/auth/issues/1933)) ([21c2256](https://github.com/supabase/auth/commit/21c2256806ab4950e9bfc0af0472a64f7d9112a7)) +* increase test coverage in conf package to 100% ([#1937](https://github.com/supabase/auth/issues/1937)) ([bc57c1c](https://github.com/supabase/auth/commit/bc57c1c25769905b29bfc9e89bf3d6b65b1030ea)) + + +### Bug Fixes + +* enable SO_REUSEPORT in listener config ([#1936](https://github.com/supabase/auth/issues/1936)) ([a474b80](https://github.com/supabase/auth/commit/a474b80cc1075eb32a7e72a05b0cdb561e61770b)) +* ignore not found error to check for pkce prefix later ([#1929](https://github.com/supabase/auth/issues/1929)) ([fbbebcc](https://github.com/supabase/auth/commit/fbbebccd5da21ea22323e6f8f853df9168c4c41e)) +* log version & migration count ([#1934](https://github.com/supabase/auth/issues/1934)) ([8078cdc](https://github.com/supabase/auth/commit/8078cdc6f275c97d84c0ba20963327af900b84d0)) +* update figma token endpoint ([#1952](https://github.com/supabase/auth/issues/1952)) ([18fbbb5](https://github.com/supabase/auth/commit/18fbbb53de04c024b6de829e390145a8452d7ab2)) +* use sys/unix instead of syscall ([#1953](https://github.com/supabase/auth/issues/1953)) ([4a6d9bc](https://github.com/supabase/auth/commit/4a6d9bcade28db3c7a6c2c610600665190c9a925)) + +## [2.169.0](https://github.com/supabase/auth/compare/v2.168.0...v2.169.0) (2025-01-27) + + +### Features + +* add an optional burstable rate limiter ([#1924](https://github.com/supabase/auth/issues/1924)) ([1f06f58](https://github.com/supabase/auth/commit/1f06f58e1434b91612c0d96c8c0435d26570f3e2)) +* cover 100% of crypto with tests ([#1892](https://github.com/supabase/auth/issues/1892)) ([174198e](https://github.com/supabase/auth/commit/174198e56f8e9b8470a717d0021c626130288d2e)) + + +### Bug Fixes + +* convert refreshed_at to UTC before updating ([#1916](https://github.com/supabase/auth/issues/1916)) ([a4c692f](https://github.com/supabase/auth/commit/a4c692f6cb1b8bf4c47ea012872af5ce93382fbf)) +* correct casing of API key authentication in openapi.yaml ([0cfd177](https://github.com/supabase/auth/commit/0cfd177b8fb1df8f62e84fbd3761ef9f90c384de)) +* improve invalid channel error message returned ([#1908](https://github.com/supabase/auth/issues/1908)) ([f72f0ee](https://github.com/supabase/auth/commit/f72f0eee328fa0aa041155f5f5dc305f0874d2bf)) +* improve saml assertion logging ([#1915](https://github.com/supabase/auth/issues/1915)) ([d6030cc](https://github.com/supabase/auth/commit/d6030ccd271a381e2a6ababa11a5beae4b79e5c3)) + +## [2.168.0](https://github.com/supabase/auth/compare/v2.167.0...v2.168.0) (2025-01-06) + + +### Features + +* set `email_verified` to true on all identities with the verified email ([#1902](https://github.com/supabase/auth/issues/1902)) ([307892f](https://github.com/supabase/auth/commit/307892f85b39150074fbb80b9c8f45ac3312aae2)) + +## [2.167.0](https://github.com/supabase/auth/compare/v2.166.0...v2.167.0) (2024-12-24) + + +### Features + +* fix argon2 parsing and comparison ([#1887](https://github.com/supabase/auth/issues/1887)) ([9dbe6ef](https://github.com/supabase/auth/commit/9dbe6ef931ae94e621d55a5f7aea4b7ee0449949)) + +## [2.166.0](https://github.com/supabase/auth/compare/v2.165.0...v2.166.0) (2024-12-23) + + +### Features + +* switch to googleapis/release-please-action, bump to 2.166.0 ([#1883](https://github.com/supabase/auth/issues/1883)) ([11a312f](https://github.com/supabase/auth/commit/11a312fcf77771b3732f2f439078225895df7a85)) + + +### Bug Fixes + +* check if session is nil ([#1873](https://github.com/supabase/auth/issues/1873)) ([fd82601](https://github.com/supabase/auth/commit/fd82601917adcd9f8c38263953eb1ef098b26b7f)) +* email_verified field not being updated on signup confirmation ([#1868](https://github.com/supabase/auth/issues/1868)) ([483463e](https://github.com/supabase/auth/commit/483463e49eec7b2974cca05eadca6b933b2145b5)) +* handle user banned error code ([#1851](https://github.com/supabase/auth/issues/1851)) ([a6918f4](https://github.com/supabase/auth/commit/a6918f49baee42899b3ae1b7b6bc126d84629c99)) +* Revert "fix: revert fallback on btree indexes when hash is unavailable" ([#1859](https://github.com/supabase/auth/issues/1859)) ([9fe5b1e](https://github.com/supabase/auth/commit/9fe5b1eebfafb385d6b5d10196aeb2a1964ab296)) +* skip cleanup for non-2xx status ([#1877](https://github.com/supabase/auth/issues/1877)) ([f572ced](https://github.com/supabase/auth/commit/f572ced3699c7f920deccce1a3539299541ec94c)) + +## [2.165.1](https://github.com/supabase/auth/compare/v2.165.0...v2.165.1) (2024-12-06) + + +### Bug Fixes + +* allow setting the mailer service headers as strings ([#1861](https://github.com/supabase/auth/issues/1861)) ([7907b56](https://github.com/supabase/auth/commit/7907b566228f7e2d76049b44cfe0cc808c109100)) + +## [2.165.0](https://github.com/supabase/auth/compare/v2.164.0...v2.165.0) (2024-12-05) + + +### Features + +* add email validation function to lower bounce rates ([#1845](https://github.com/supabase/auth/issues/1845)) ([2c291f0](https://github.com/supabase/auth/commit/2c291f0356f3e91063b6b43bf2a21625b0ce0ebd)) +* use embedded migrations for `migrate` command ([#1843](https://github.com/supabase/auth/issues/1843)) ([e358da5](https://github.com/supabase/auth/commit/e358da5f0e267725a77308461d0a4126436fc537)) + + +### Bug Fixes + +* fallback on btree indexes when hash is unavailable ([#1856](https://github.com/supabase/auth/issues/1856)) ([b33bc31](https://github.com/supabase/auth/commit/b33bc31c07549dc9dc221100995d6f6b6754fd3a)) +* return the error code instead of status code ([#1855](https://github.com/supabase/auth/issues/1855)) ([834a380](https://github.com/supabase/auth/commit/834a380d803ae9ce59ce5ee233fa3a78a984fe68)) +* revert fallback on btree indexes when hash is unavailable ([#1858](https://github.com/supabase/auth/issues/1858)) ([1c7202f](https://github.com/supabase/auth/commit/1c7202ff835856562ee66b33be131eca769acf1d)) +* update ip mismatch error message ([#1849](https://github.com/supabase/auth/issues/1849)) ([49fbbf0](https://github.com/supabase/auth/commit/49fbbf03917a1085c58e9a1ff76c247ae6bb9ca7)) + +## [2.164.0](https://github.com/supabase/auth/compare/v2.163.2...v2.164.0) (2024-11-13) + + +### Features + +* return validation failed error if captcha request was not json ([#1815](https://github.com/supabase/auth/issues/1815)) ([26d2e36](https://github.com/supabase/auth/commit/26d2e36bba29eb8a6ddba556acfd0820f3bfde5d)) + + +### Bug Fixes + +* add error codes to refresh token flow ([#1824](https://github.com/supabase/auth/issues/1824)) ([4614dc5](https://github.com/supabase/auth/commit/4614dc54ab1dcb5390cfed05441e7888af017d92)) +* add test coverage for rate limits with 0 permitted events ([#1834](https://github.com/supabase/auth/issues/1834)) ([7c3cf26](https://github.com/supabase/auth/commit/7c3cf26cfe2a3e4de579d10509945186ad719855)) +* correct web authn aaguid column naming ([#1826](https://github.com/supabase/auth/issues/1826)) ([0a589d0](https://github.com/supabase/auth/commit/0a589d04e1cd9310cb260d329bc8beb050adf8da)) +* default to files:read scope for Figma provider ([#1831](https://github.com/supabase/auth/issues/1831)) ([9ce2857](https://github.com/supabase/auth/commit/9ce28570bf3da9571198d44d693c7ad7038cde33)) +* improve error messaging for http hooks ([#1821](https://github.com/supabase/auth/issues/1821)) ([fa020d0](https://github.com/supabase/auth/commit/fa020d0fc292d5c381c57ecac6666d9ff657e4c4)) +* make drop_uniqueness_constraint_on_phone idempotent ([#1817](https://github.com/supabase/auth/issues/1817)) ([158e473](https://github.com/supabase/auth/commit/158e4732afa17620cdd89c85b7b57569feea5c21)) +* possible panic if refresh token has a null session_id ([#1822](https://github.com/supabase/auth/issues/1822)) ([a7129df](https://github.com/supabase/auth/commit/a7129df4e1d91a042b56ff1f041b9c6598825475)) +* rate limits of 0 take precedence over MAILER_AUTO_CONFIRM ([#1837](https://github.com/supabase/auth/issues/1837)) ([cb7894e](https://github.com/supabase/auth/commit/cb7894e1119d27d527dedcca22d8b3d433beddac)) + +## [2.163.2](https://github.com/supabase/auth/compare/v2.163.1...v2.163.2) (2024-10-22) + + +### Bug Fixes + +* ignore rate limits for autoconfirm ([#1810](https://github.com/supabase/auth/issues/1810)) ([9ce2340](https://github.com/supabase/auth/commit/9ce23409f960a8efa55075931138624cb681eca5)) + +## [2.163.1](https://github.com/supabase/auth/compare/v2.163.0...v2.163.1) (2024-10-22) + + +### Bug Fixes + +* external host validation ([#1808](https://github.com/supabase/auth/issues/1808)) ([4f6a461](https://github.com/supabase/auth/commit/4f6a4617074e61ba3b31836ccb112014904ce97c)), closes [#1228](https://github.com/supabase/auth/issues/1228) + +## [2.163.0](https://github.com/supabase/auth/compare/v2.162.2...v2.163.0) (2024-10-15) + + +### Features + +* add mail header support via `GOTRUE_SMTP_HEADERS` with `$messageType` ([#1804](https://github.com/supabase/auth/issues/1804)) ([99d6a13](https://github.com/supabase/auth/commit/99d6a134c44554a8ad06695e1dff54c942c8335d)) +* add MFA for WebAuthn ([#1775](https://github.com/supabase/auth/issues/1775)) ([8cc2f0e](https://github.com/supabase/auth/commit/8cc2f0e14d06d0feb56b25a0278fda9e213b6b5a)) +* configurable email and sms rate limiting ([#1800](https://github.com/supabase/auth/issues/1800)) ([5e94047](https://github.com/supabase/auth/commit/5e9404717e1c962ab729cde150ef5b40ea31a6e8)) +* mailer logging ([#1805](https://github.com/supabase/auth/issues/1805)) ([9354b83](https://github.com/supabase/auth/commit/9354b83a48a3edcb49197c997a1e96efc80c5383)) +* preserve rate limiters in memory across configuration reloads ([#1792](https://github.com/supabase/auth/issues/1792)) ([0a3968b](https://github.com/supabase/auth/commit/0a3968b02b9f044bfb7e5ebc71dca970d2bb7807)) + + +### Bug Fixes + +* add twilio verify support on mfa ([#1714](https://github.com/supabase/auth/issues/1714)) ([aeb5d8f](https://github.com/supabase/auth/commit/aeb5d8f8f18af60ce369cab5714979ac0c208308)) +* email header setting no longer misleading ([#1802](https://github.com/supabase/auth/issues/1802)) ([3af03be](https://github.com/supabase/auth/commit/3af03be6b65c40f3f4f62ce9ab989a20d75ae53a)) +* enforce authorized address checks on send email only ([#1806](https://github.com/supabase/auth/issues/1806)) ([c0c5b23](https://github.com/supabase/auth/commit/c0c5b23728c8fb633dae23aa4b29ed60e2691a2b)) +* fix `getExcludedColumns` slice allocation ([#1788](https://github.com/supabase/auth/issues/1788)) ([7f006b6](https://github.com/supabase/auth/commit/7f006b63c8d7e28e55a6d471881e9c118df80585)) +* Fix reqPath for bypass check for verify EP ([#1789](https://github.com/supabase/auth/issues/1789)) ([646dc66](https://github.com/supabase/auth/commit/646dc66ea8d59a7f78bf5a5e55d9b5065a718c23)) +* inline mailme package for easy development ([#1803](https://github.com/supabase/auth/issues/1803)) ([fa6f729](https://github.com/supabase/auth/commit/fa6f729a027eff551db104550fa626088e00bc15)) + +## [2.162.2](https://github.com/supabase/auth/compare/v2.162.1...v2.162.2) (2024-10-05) + + +### Bug Fixes + +* refactor mfa validation into functions ([#1780](https://github.com/supabase/auth/issues/1780)) ([410b8ac](https://github.com/supabase/auth/commit/410b8acdd659fc4c929fe57a9e9dba4c76da305d)) +* upgrade ci Go version ([#1782](https://github.com/supabase/auth/issues/1782)) ([97a48f6](https://github.com/supabase/auth/commit/97a48f6daaa2edda5b568939cbb1007ccdf33cfc)) +* validateEmail should normalise emails ([#1790](https://github.com/supabase/auth/issues/1790)) ([2e9b144](https://github.com/supabase/auth/commit/2e9b144a0cbf2d26d3c4c2eafbff1899a36aeb3b)) + +## [2.162.1](https://github.com/supabase/auth/compare/v2.162.0...v2.162.1) (2024-10-03) + + +### Bug Fixes + +* bypass check for token & verify endpoints ([#1785](https://github.com/supabase/auth/issues/1785)) ([9ac2ea0](https://github.com/supabase/auth/commit/9ac2ea0180826cd2f65e679524aabfb10666e973)) + +## [2.162.0](https://github.com/supabase/auth/compare/v2.161.0...v2.162.0) (2024-09-27) + + +### Features + +* add support for migration of firebase scrypt passwords ([#1768](https://github.com/supabase/auth/issues/1768)) ([ba00f75](https://github.com/supabase/auth/commit/ba00f75c28d6708ddf8ee151ce18f2d6193689ef)) + + +### Bug Fixes + +* apply authorized email restriction to non-admin routes ([#1778](https://github.com/supabase/auth/issues/1778)) ([1af203f](https://github.com/supabase/auth/commit/1af203f92372e6db12454a0d319aad8ce3d149e7)) +* magiclink failing due to passwordStrength check ([#1769](https://github.com/supabase/auth/issues/1769)) ([7a5411f](https://github.com/supabase/auth/commit/7a5411f1d4247478f91027bc4969cbbe95b7774c)) + +## [2.161.0](https://github.com/supabase/auth/compare/v2.160.0...v2.161.0) (2024-09-24) + + +### Features + +* add `x-sb-error-code` header, show error code in logs ([#1765](https://github.com/supabase/auth/issues/1765)) ([ed91c59](https://github.com/supabase/auth/commit/ed91c59aa332738bd0ac4b994aeec2cdf193a068)) +* add webauthn configuration variables ([#1773](https://github.com/supabase/auth/issues/1773)) ([77d5897](https://github.com/supabase/auth/commit/77d58976ae624dbb7f8abee041dd4557aab81109)) +* config reloading ([#1771](https://github.com/supabase/auth/issues/1771)) ([6ee0091](https://github.com/supabase/auth/commit/6ee009163bfe451e2a0b923705e073928a12c004)) + + +### Bug Fixes + +* add additional information around errors for missing content type header ([#1576](https://github.com/supabase/auth/issues/1576)) ([c2b2f96](https://github.com/supabase/auth/commit/c2b2f96f07c97c15597cd972b1cd672238d87cdc)) +* add token to hook payload for non-secure email change ([#1763](https://github.com/supabase/auth/issues/1763)) ([7e472ad](https://github.com/supabase/auth/commit/7e472ad72042e86882dab3fddce9fafa66a8236c)) +* update aal requirements to update user ([#1766](https://github.com/supabase/auth/issues/1766)) ([25d9874](https://github.com/supabase/auth/commit/25d98743f6cc2cca2b490a087f468c8556ec5e44)) +* update mfa admin methods ([#1774](https://github.com/supabase/auth/issues/1774)) ([567ea7e](https://github.com/supabase/auth/commit/567ea7ebd18eacc5e6daea8adc72e59e94459991)) +* user sanitization should clean up email change info too ([#1759](https://github.com/supabase/auth/issues/1759)) ([9d419b4](https://github.com/supabase/auth/commit/9d419b400f0637b10e5c235b8fd5bac0d69352bd)) + +## [2.160.0](https://github.com/supabase/auth/compare/v2.159.2...v2.160.0) (2024-09-02) + + +### Features + +* add authorized email address support ([#1757](https://github.com/supabase/auth/issues/1757)) ([f3a28d1](https://github.com/supabase/auth/commit/f3a28d182d193cf528cc72a985dfeaf7ecb67056)) +* add option to disable magic links ([#1756](https://github.com/supabase/auth/issues/1756)) ([2ad0737](https://github.com/supabase/auth/commit/2ad07373aa9239eba94abdabbb01c9abfa8c48de)) +* add support for saml encrypted assertions ([#1752](https://github.com/supabase/auth/issues/1752)) ([c5480ef](https://github.com/supabase/auth/commit/c5480ef83248ec2e7e3d3d87f92f43f17161ed25)) + + +### Bug Fixes + +* apply shared limiters before email / sms is sent ([#1748](https://github.com/supabase/auth/issues/1748)) ([bf276ab](https://github.com/supabase/auth/commit/bf276ab49753642793471815727559172fea4efc)) +* simplify WaitForCleanup ([#1747](https://github.com/supabase/auth/issues/1747)) ([0084625](https://github.com/supabase/auth/commit/0084625ad0790dd7c14b412d932425f4b84bb4c8)) + +## [2.159.2](https://github.com/supabase/auth/compare/v2.159.1...v2.159.2) (2024-08-28) + + +### Bug Fixes + +* allow anonymous user to update password ([#1739](https://github.com/supabase/auth/issues/1739)) ([2d51956](https://github.com/supabase/auth/commit/2d519569d7b8540886d0a64bf3e561ef5f91eb63)) +* hide hook name ([#1743](https://github.com/supabase/auth/issues/1743)) ([7e38f4c](https://github.com/supabase/auth/commit/7e38f4cf37768fe2adf92bbd0723d1d521b3d74c)) +* remove server side cookie token methods ([#1742](https://github.com/supabase/auth/issues/1742)) ([c6efec4](https://github.com/supabase/auth/commit/c6efec4cbc950e01e1fd06d45ed821bd27c2ad08)) + +## [2.159.1](https://github.com/supabase/auth/compare/v2.159.0...v2.159.1) (2024-08-23) + + +### Bug Fixes + +* return oauth identity when user is created ([#1736](https://github.com/supabase/auth/issues/1736)) ([60cfb60](https://github.com/supabase/auth/commit/60cfb6063afa574dfe4993df6b0e087d4df71309)) + +## [2.159.0](https://github.com/supabase/auth/compare/v2.158.1...v2.159.0) (2024-08-21) + + +### Features + +* Vercel marketplace OIDC ([#1731](https://github.com/supabase/auth/issues/1731)) ([a9ff361](https://github.com/supabase/auth/commit/a9ff3612196af4a228b53a8bfb9c11785bcfba8d)) + + +### Bug Fixes + +* add error codes to password login flow ([#1721](https://github.com/supabase/auth/issues/1721)) ([4351226](https://github.com/supabase/auth/commit/435122627a0784f1c5cb76d7e08caa1f6259423b)) +* change phone constraint to per user ([#1713](https://github.com/supabase/auth/issues/1713)) ([b9bc769](https://github.com/supabase/auth/commit/b9bc769b93b6e700925fcbc1ebf8bf9678034205)) +* custom SMS does not work with Twilio Verify ([#1733](https://github.com/supabase/auth/issues/1733)) ([dc2391d](https://github.com/supabase/auth/commit/dc2391d15f2c0725710aa388cd32a18797e6769c)) +* ignore errors if transaction has closed already ([#1726](https://github.com/supabase/auth/issues/1726)) ([53c11d1](https://github.com/supabase/auth/commit/53c11d173a79ae5c004871b1b5840c6f9425a080)) +* redirect invalid state errors to site url ([#1722](https://github.com/supabase/auth/issues/1722)) ([b2b1123](https://github.com/supabase/auth/commit/b2b11239dc9f9bd3c85d76f6c23ee94beb3330bb)) +* remove TOTP field for phone enroll response ([#1717](https://github.com/supabase/auth/issues/1717)) ([4b04327](https://github.com/supabase/auth/commit/4b043275dd2d94600a8138d4ebf4638754ed926b)) +* use signing jwk to sign oauth state ([#1728](https://github.com/supabase/auth/issues/1728)) ([66fd0c8](https://github.com/supabase/auth/commit/66fd0c8434388bbff1e1bf02f40517aca0e9d339)) + +## [2.158.1](https://github.com/supabase/auth/compare/v2.158.0...v2.158.1) (2024-08-05) + + +### Bug Fixes + +* add last_challenged_at field to mfa factors ([#1705](https://github.com/supabase/auth/issues/1705)) ([29cbeb7](https://github.com/supabase/auth/commit/29cbeb799ff35ce528bfbd01b7103a24903d8061)) +* allow enabling sms hook without setting up sms provider ([#1704](https://github.com/supabase/auth/issues/1704)) ([575e88a](https://github.com/supabase/auth/commit/575e88ac345adaeb76ab6aae077307fdab9cda3c)) +* drop the MFA_ENABLED config ([#1701](https://github.com/supabase/auth/issues/1701)) ([078c3a8](https://github.com/supabase/auth/commit/078c3a8adcd51e57b68ab1b582549f5813cccd14)) +* enforce uniqueness on verified phone numbers ([#1693](https://github.com/supabase/auth/issues/1693)) ([70446cc](https://github.com/supabase/auth/commit/70446cc11d70b0493d742fe03f272330bb5b633e)) +* expose `X-Supabase-Api-Version` header in CORS ([#1612](https://github.com/supabase/auth/issues/1612)) ([6ccd814](https://github.com/supabase/auth/commit/6ccd814309dca70a9e3585543887194b05d725d3)) +* include factor_id in query ([#1702](https://github.com/supabase/auth/issues/1702)) ([ac14e82](https://github.com/supabase/auth/commit/ac14e82b33545466184da99e99b9d3fe5f3876d9)) +* move is owned by check to load factor ([#1703](https://github.com/supabase/auth/issues/1703)) ([701a779](https://github.com/supabase/auth/commit/701a779cf092e777dd4ad4954dc650164b09ab32)) +* refactor TOTP MFA into separate methods ([#1698](https://github.com/supabase/auth/issues/1698)) ([250d92f](https://github.com/supabase/auth/commit/250d92f9a18d38089d1bf262ef9088022a446965)) +* remove check for content-length ([#1700](https://github.com/supabase/auth/issues/1700)) ([81b332d](https://github.com/supabase/auth/commit/81b332d2f48622008469d2c5a9b130465a65f2a3)) +* remove FindFactorsByUser ([#1707](https://github.com/supabase/auth/issues/1707)) ([af8e2dd](https://github.com/supabase/auth/commit/af8e2dda15a1234a05e7d2d34d316eaa029e0912)) +* update openapi spec for MFA (Phone) ([#1689](https://github.com/supabase/auth/issues/1689)) ([a3da4b8](https://github.com/supabase/auth/commit/a3da4b89820c37f03ea128889616aca598d99f68)) + +## [2.158.0](https://github.com/supabase/auth/compare/v2.157.0...v2.158.0) (2024-07-31) + + +### Features + +* add hook log entry with `run_hook` action ([#1684](https://github.com/supabase/auth/issues/1684)) ([46491b8](https://github.com/supabase/auth/commit/46491b867a4f5896494417391392a373a453fa5f)) +* MFA (Phone) ([#1668](https://github.com/supabase/auth/issues/1668)) ([ae091aa](https://github.com/supabase/auth/commit/ae091aa942bdc5bc97481037508ec3bb4079d859)) + + +### Bug Fixes + +* maintain backward compatibility for asymmetric JWTs ([#1690](https://github.com/supabase/auth/issues/1690)) ([0ad1402](https://github.com/supabase/auth/commit/0ad1402444348e47e1e42be186b3f052d31be824)) +* MFA NewFactor to default to creating unverfied factors ([#1692](https://github.com/supabase/auth/issues/1692)) ([3d448fa](https://github.com/supabase/auth/commit/3d448fa73cb77eb8511dbc47bfafecce4a4a2150)) +* minor spelling errors ([#1688](https://github.com/supabase/auth/issues/1688)) ([6aca52b](https://github.com/supabase/auth/commit/6aca52b56f8a6254de7709c767b9a5649f1da248)), closes [#1682](https://github.com/supabase/auth/issues/1682) +* treat `GOTRUE_MFA_ENABLED` as meaning TOTP enabled on enroll and verify ([#1694](https://github.com/supabase/auth/issues/1694)) ([8015251](https://github.com/supabase/auth/commit/8015251400bd52cbdad3ea28afb83b1cdfe816dd)) +* update mfa phone migration to be idempotent ([#1687](https://github.com/supabase/auth/issues/1687)) ([fdff1e7](https://github.com/supabase/auth/commit/fdff1e703bccf93217636266f1862bd0a9205edb)) + +## [2.157.0](https://github.com/supabase/auth/compare/v2.156.0...v2.157.0) (2024-07-26) + + +### Features + +* add asymmetric jwt support ([#1674](https://github.com/supabase/auth/issues/1674)) ([c7a2be3](https://github.com/supabase/auth/commit/c7a2be347b301b666e99adc3d3fed78c5e287c82)) + +## [2.156.0](https://github.com/supabase/auth/compare/v2.155.6...v2.156.0) (2024-07-25) + + +### Features + +* add is_anonymous claim to Auth hook jsonschema ([#1667](https://github.com/supabase/auth/issues/1667)) ([f9df65c](https://github.com/supabase/auth/commit/f9df65c91e226084abfa2e868ab6bab892d16d2f)) + + +### Bug Fixes + +* restrict autoconfirm email change to anonymous users ([#1679](https://github.com/supabase/auth/issues/1679)) ([b57e223](https://github.com/supabase/auth/commit/b57e2230102280ed873acf70be1aeb5a2f6f7a4f)) + +## [2.155.6](https://github.com/supabase/auth/compare/v2.155.5...v2.155.6) (2024-07-22) + + +### Bug Fixes + +* use deep equal ([#1672](https://github.com/supabase/auth/issues/1672)) ([8efd57d](https://github.com/supabase/auth/commit/8efd57dab40346762a04bac61b314ce05d6fa69c)) + +## [2.155.5](https://github.com/supabase/auth/compare/v2.155.4...v2.155.5) (2024-07-19) + + +### Bug Fixes + +* check password max length in checkPasswordStrength ([#1659](https://github.com/supabase/auth/issues/1659)) ([1858c93](https://github.com/supabase/auth/commit/1858c93bba6f5bc41e4c65489f12c1a0786a1f2b)) +* don't update attribute mapping if nil ([#1665](https://github.com/supabase/auth/issues/1665)) ([7e67f3e](https://github.com/supabase/auth/commit/7e67f3edbf81766df297a66f52a8e472583438c6)) +* refactor mfa models and add observability to loadFactor ([#1669](https://github.com/supabase/auth/issues/1669)) ([822fb93](https://github.com/supabase/auth/commit/822fb93faab325ba3d4bb628dff43381d68d0b5d)) + +## [2.155.4](https://github.com/supabase/auth/compare/v2.155.3...v2.155.4) (2024-07-17) + + +### Bug Fixes + +* treat empty string as nil in `encrypted_password` ([#1663](https://github.com/supabase/auth/issues/1663)) ([f99286e](https://github.com/supabase/auth/commit/f99286eaed505daf3db6f381265ef6024e7e36d2)) + +## [2.155.3](https://github.com/supabase/auth/compare/v2.155.2...v2.155.3) (2024-07-12) + + +### Bug Fixes + +* serialize jwt as string ([#1657](https://github.com/supabase/auth/issues/1657)) ([98d8324](https://github.com/supabase/auth/commit/98d83245e40d606438eb0afdbf474276179fd91d)) + +## [2.155.2](https://github.com/supabase/auth/compare/v2.155.1...v2.155.2) (2024-07-12) + + +### Bug Fixes + +* improve session error logging ([#1655](https://github.com/supabase/auth/issues/1655)) ([5a6793e](https://github.com/supabase/auth/commit/5a6793ee8fce7a089750fe10b3b63bb0a19d6d21)) +* omit empty string from name & use case-insensitive equality for comparing SAML attributes ([#1654](https://github.com/supabase/auth/issues/1654)) ([bf5381a](https://github.com/supabase/auth/commit/bf5381a6b1c686955dc4e39fe5fb806ffd309563)) +* set rate limit log level to warn ([#1652](https://github.com/supabase/auth/issues/1652)) ([10ca9c8](https://github.com/supabase/auth/commit/10ca9c806e4b67a371897f1b3f93c515764c4240)) + +## [2.155.1](https://github.com/supabase/auth/compare/v2.155.0...v2.155.1) (2024-07-04) + + +### Bug Fixes + +* apply mailer autoconfirm config to update user email ([#1646](https://github.com/supabase/auth/issues/1646)) ([a518505](https://github.com/supabase/auth/commit/a5185058e72509b0781e0eb59910ecdbb8676fee)) +* check for empty aud string ([#1649](https://github.com/supabase/auth/issues/1649)) ([42c1d45](https://github.com/supabase/auth/commit/42c1d4526b98203664d4a22c23014ecd0b4951f9)) +* return proper error if sms rate limit is exceeded ([#1647](https://github.com/supabase/auth/issues/1647)) ([3c8d765](https://github.com/supabase/auth/commit/3c8d7656431ac4b2e80726b7c37adb8f0c778495)) + +## [2.155.0](https://github.com/supabase/auth/compare/v2.154.2...v2.155.0) (2024-07-03) + + +### Features + +* add `password_hash` and `id` fields to admin create user ([#1641](https://github.com/supabase/auth/issues/1641)) ([20d59f1](https://github.com/supabase/auth/commit/20d59f10b601577683d05bcd7d2128ff4bc462a0)) + + +### Bug Fixes + +* improve mfa verify logs ([#1635](https://github.com/supabase/auth/issues/1635)) ([d8b47f9](https://github.com/supabase/auth/commit/d8b47f9d3f0dc8f97ad1de49e45f452ebc726481)) +* invited users should have a temporary password generated ([#1644](https://github.com/supabase/auth/issues/1644)) ([3f70d9d](https://github.com/supabase/auth/commit/3f70d9d8974d0e9c437c51e1312ad17ce9056ec9)) +* upgrade golang-jwt to v5 ([#1639](https://github.com/supabase/auth/issues/1639)) ([2cb97f0](https://github.com/supabase/auth/commit/2cb97f080fa4695766985cc4792d09476534be68)) +* use pointer for `user.EncryptedPassword` ([#1637](https://github.com/supabase/auth/issues/1637)) ([bbecbd6](https://github.com/supabase/auth/commit/bbecbd61a46b0c528b1191f48d51f166c06f4b16)) + +## [2.154.2](https://github.com/supabase/auth/compare/v2.154.1...v2.154.2) (2024-06-24) + + +### Bug Fixes + +* publish to ghcr.io/supabase/auth ([#1626](https://github.com/supabase/auth/issues/1626)) ([930aa3e](https://github.com/supabase/auth/commit/930aa3edb633823d4510c2aff675672df06f1211)), closes [#1625](https://github.com/supabase/auth/issues/1625) +* revert define search path in auth functions ([#1634](https://github.com/supabase/auth/issues/1634)) ([155e87e](https://github.com/supabase/auth/commit/155e87ef8129366d665968f64d1fc66676d07e16)) +* update MaxFrequency error message to reflect number of seconds ([#1540](https://github.com/supabase/auth/issues/1540)) ([e81c25d](https://github.com/supabase/auth/commit/e81c25d19551fdebfc5197d96bc220ddb0f8227b)) + +## [2.154.1](https://github.com/supabase/auth/compare/v2.154.0...v2.154.1) (2024-06-17) + + +### Bug Fixes + +* add ip based limiter ([#1622](https://github.com/supabase/auth/issues/1622)) ([06464c0](https://github.com/supabase/auth/commit/06464c013571253d1f18f7ae5e840826c4bd84a7)) +* admin user update should update is_anonymous field ([#1623](https://github.com/supabase/auth/issues/1623)) ([f5c6fcd](https://github.com/supabase/auth/commit/f5c6fcd9c3fee0f793f96880a8caebc5b5cb0916)) + +## [2.154.0](https://github.com/supabase/auth/compare/v2.153.0...v2.154.0) (2024-06-12) + + +### Features + +* add max length check for email ([#1508](https://github.com/supabase/auth/issues/1508)) ([f9c13c0](https://github.com/supabase/auth/commit/f9c13c0ad5c556bede49d3e0f6e5f58ca26161c3)) +* add support for Slack OAuth V2 ([#1591](https://github.com/supabase/auth/issues/1591)) ([bb99251](https://github.com/supabase/auth/commit/bb992519cdf7578dc02cd7de55e2e6aa09b4c0f3)) +* encrypt sensitive columns ([#1593](https://github.com/supabase/auth/issues/1593)) ([e4a4758](https://github.com/supabase/auth/commit/e4a475820b2dc1f985bd37df15a8ab9e781626f5)) +* upgrade otel to v1.26 ([#1585](https://github.com/supabase/auth/issues/1585)) ([cdd13ad](https://github.com/supabase/auth/commit/cdd13adec02eb0c9401bc55a2915c1005d50dea1)) +* use largest avatar from spotify instead ([#1210](https://github.com/supabase/auth/issues/1210)) ([4f9994b](https://github.com/supabase/auth/commit/4f9994bf792c3887f2f45910b11a9c19ee3a896b)), closes [#1209](https://github.com/supabase/auth/issues/1209) + + +### Bug Fixes + +* define search path in auth functions ([#1616](https://github.com/supabase/auth/issues/1616)) ([357bda2](https://github.com/supabase/auth/commit/357bda23cb2abd12748df80a9d27288aa548534d)) +* enable rls & update grants for auth tables ([#1617](https://github.com/supabase/auth/issues/1617)) ([28967aa](https://github.com/supabase/auth/commit/28967aa4b5db2363cc581c9da0d64e974eb7b64c)) + +## [2.153.0](https://github.com/supabase/auth/compare/v2.152.0...v2.153.0) (2024-06-04) + + +### Features + +* add SAML specific external URL config ([#1599](https://github.com/supabase/auth/issues/1599)) ([b352719](https://github.com/supabase/auth/commit/b3527190560381fafe9ba2fae4adc3b73703024a)) +* add support for verifying argon2i and argon2id passwords ([#1597](https://github.com/supabase/auth/issues/1597)) ([55409f7](https://github.com/supabase/auth/commit/55409f797bea55068a3fafdddd6cfdb78feba1b4)) +* make the email client explicity set the format to be HTML ([#1149](https://github.com/supabase/auth/issues/1149)) ([53e223a](https://github.com/supabase/auth/commit/53e223abdf29f4abcad13f99baf00daedcb00c3f)) + + +### Bug Fixes + +* call write header in write if not written ([#1598](https://github.com/supabase/auth/issues/1598)) ([0ef7eb3](https://github.com/supabase/auth/commit/0ef7eb30619d4c365e06a94a79b9cb0333d792da)) +* deadlock issue with timeout middleware write ([#1595](https://github.com/supabase/auth/issues/1595)) ([6c9fbd4](https://github.com/supabase/auth/commit/6c9fbd4bd5623c729906fca7857ab508166a3056)) +* improve token OIDC logging ([#1606](https://github.com/supabase/auth/issues/1606)) ([5262683](https://github.com/supabase/auth/commit/526268311844467664e89c8329e5aaee817dbbaf)) +* update contributing to use v1.22 ([#1609](https://github.com/supabase/auth/issues/1609)) ([5894d9e](https://github.com/supabase/auth/commit/5894d9e41e7681512a9904ad47082a705e948c98)) + +## [2.152.0](https://github.com/supabase/auth/compare/v2.151.0...v2.152.0) (2024-05-22) + + +### Features + +* new timeout writer implementation ([#1584](https://github.com/supabase/auth/issues/1584)) ([72614a1](https://github.com/supabase/auth/commit/72614a1fce27888f294772b512f8e31c55a36d87)) +* remove legacy lookup in users for one_time_tokens (phase II) ([#1569](https://github.com/supabase/auth/issues/1569)) ([39ca026](https://github.com/supabase/auth/commit/39ca026035f6c61d206d31772c661b326c2a424c)) +* update chi version ([#1581](https://github.com/supabase/auth/issues/1581)) ([c64ae3d](https://github.com/supabase/auth/commit/c64ae3dd775e8fb3022239252c31b4ee73893237)) +* update openapi spec with identity and is_anonymous fields ([#1573](https://github.com/supabase/auth/issues/1573)) ([86a79df](https://github.com/supabase/auth/commit/86a79df9ecfcf09fda0b8e07afbc41154fbb7d9d)) + + +### Bug Fixes + +* improve logging structure ([#1583](https://github.com/supabase/auth/issues/1583)) ([c22fc15](https://github.com/supabase/auth/commit/c22fc15d2a8383e95a2364f383dfa7dce5f5df88)) +* sms verify should update is_anonymous field ([#1580](https://github.com/supabase/auth/issues/1580)) ([e5f98cb](https://github.com/supabase/auth/commit/e5f98cb9e24ecebb0b7dc88c495fd456cc73fcba)) +* use api_external_url domain as localname ([#1575](https://github.com/supabase/auth/issues/1575)) ([ed2b490](https://github.com/supabase/auth/commit/ed2b4907244281e4c54aaef74b1f4c8a8e3d97c9)) + +## [2.151.0](https://github.com/supabase/auth/compare/v2.150.1...v2.151.0) (2024-05-06) + + +### Features + +* refactor one-time tokens for performance ([#1558](https://github.com/supabase/auth/issues/1558)) ([d1cf8d9](https://github.com/supabase/auth/commit/d1cf8d9096e9183d7772b73031de8ecbd66e912b)) + + +### Bug Fixes + +* do call send sms hook when SMS autoconfirm is enabled ([#1562](https://github.com/supabase/auth/issues/1562)) ([bfe4d98](https://github.com/supabase/auth/commit/bfe4d988f3768b0407526bcc7979fb21d8cbebb3)) +* format test otps ([#1567](https://github.com/supabase/auth/issues/1567)) ([434a59a](https://github.com/supabase/auth/commit/434a59ae387c35fd6629ec7c674d439537e344e5)) +* log final writer error instead of handling ([#1564](https://github.com/supabase/auth/issues/1564)) ([170bd66](https://github.com/supabase/auth/commit/170bd6615405afc852c7107f7358dfc837bad737)) + +## [2.150.1](https://github.com/supabase/auth/compare/v2.150.0...v2.150.1) (2024-04-28) + + +### Bug Fixes + +* add db conn max idle time setting ([#1555](https://github.com/supabase/auth/issues/1555)) ([2caa7b4](https://github.com/supabase/auth/commit/2caa7b4d75d2ff54af20f3e7a30a8eeec8cbcda9)) + +## [2.150.0](https://github.com/supabase/auth/compare/v2.149.0...v2.150.0) (2024-04-25) + + +### Features + +* add support for Azure CIAM login ([#1541](https://github.com/supabase/auth/issues/1541)) ([1cb4f96](https://github.com/supabase/auth/commit/1cb4f96bdc7ef3ef995781b4cf3c4364663a2bf3)) +* add timeout middleware ([#1529](https://github.com/supabase/auth/issues/1529)) ([f96ff31](https://github.com/supabase/auth/commit/f96ff31040b28e3a7373b4fd41b7334eda1b413e)) +* allow for postgres and http functions on each extensibility point ([#1528](https://github.com/supabase/auth/issues/1528)) ([348a1da](https://github.com/supabase/auth/commit/348a1daee24f6e44b14c018830b748e46d34b4c2)) +* merge provider metadata on link account ([#1552](https://github.com/supabase/auth/issues/1552)) ([bd8b5c4](https://github.com/supabase/auth/commit/bd8b5c41dd544575e1a52ccf1ef3f0fdee67458c)) +* send over user in SendSMS Hook instead of UserID ([#1551](https://github.com/supabase/auth/issues/1551)) ([d4d743c](https://github.com/supabase/auth/commit/d4d743c2ae9490e1b3249387e3b0d60df6913c68)) + + +### Bug Fixes + +* return error if session id does not exist ([#1538](https://github.com/supabase/auth/issues/1538)) ([91e9eca](https://github.com/supabase/auth/commit/91e9ecabe33a1c022f8e82a6050c22a7ca42de48)) + +## [2.149.0](https://github.com/supabase/auth/compare/v2.148.0...v2.149.0) (2024-04-15) + + +### Features + +* refactor generate accesss token to take in request ([#1531](https://github.com/supabase/auth/issues/1531)) ([e4f2b59](https://github.com/supabase/auth/commit/e4f2b59e8e1f8158b6461a384349f1a32cc1bf9a)) + + +### Bug Fixes + +* linkedin_oidc provider error ([#1534](https://github.com/supabase/auth/issues/1534)) ([4f5e8e5](https://github.com/supabase/auth/commit/4f5e8e5120531e5a103fbdda91b51cabcb4e1a8c)) +* revert patch for linkedin_oidc provider error ([#1535](https://github.com/supabase/auth/issues/1535)) ([58ef4af](https://github.com/supabase/auth/commit/58ef4af0b4224b78cd9e59428788d16a8d31e562)) +* update linkedin issuer url ([#1536](https://github.com/supabase/auth/issues/1536)) ([10d6d8b](https://github.com/supabase/auth/commit/10d6d8b1eafa504da2b2a351d1f64a3a832ab1b9)) + +## [2.148.0](https://github.com/supabase/auth/compare/v2.147.1...v2.148.0) (2024-04-10) + + +### Features + +* add array attribute mapping for SAML ([#1526](https://github.com/supabase/auth/issues/1526)) ([7326285](https://github.com/supabase/auth/commit/7326285c8af5c42e5c0c2d729ab224cf33ac3a1f)) + +## [2.147.1](https://github.com/supabase/auth/compare/v2.147.0...v2.147.1) (2024-04-09) + + +### Bug Fixes + +* add validation and proper decoding on send email hook ([#1520](https://github.com/supabase/auth/issues/1520)) ([e19e762](https://github.com/supabase/auth/commit/e19e762e3e29729a1d1164c65461427822cc87f1)) +* remove deprecated LogoutAllRefreshTokens ([#1519](https://github.com/supabase/auth/issues/1519)) ([35533ea](https://github.com/supabase/auth/commit/35533ea100669559e1209ecc7b091db3657234d9)) + +## [2.147.0](https://github.com/supabase/auth/compare/v2.146.0...v2.147.0) (2024-04-05) + + +### Features + +* add send email Hook ([#1512](https://github.com/supabase/auth/issues/1512)) ([cf42e02](https://github.com/supabase/auth/commit/cf42e02ec63779f52b1652a7413f64994964c82d)) + +## [2.146.0](https://github.com/supabase/auth/compare/v2.145.0...v2.146.0) (2024-04-03) + + +### Features + +* add custom sms hook ([#1474](https://github.com/supabase/auth/issues/1474)) ([0f6b29a](https://github.com/supabase/auth/commit/0f6b29a46f1dcbf92aa1f7cb702f42e7640f5f93)) +* forbid generating an access token without a session ([#1504](https://github.com/supabase/auth/issues/1504)) ([795e93d](https://github.com/supabase/auth/commit/795e93d0afbe94bcd78489a3319a970b7bf8e8bc)) + + +### Bug Fixes + +* add cleanup statement for anonymous users ([#1497](https://github.com/supabase/auth/issues/1497)) ([cf2372a](https://github.com/supabase/auth/commit/cf2372a177796b829b72454e7491ce768bf5a42f)) +* generate signup link should not error ([#1514](https://github.com/supabase/auth/issues/1514)) ([4fc3881](https://github.com/supabase/auth/commit/4fc388186ac7e7a9a32ca9b963a83d6ac2eb7603)) +* move all EmailActionTypes to mailer package ([#1510](https://github.com/supabase/auth/issues/1510)) ([765db08](https://github.com/supabase/auth/commit/765db08582669a1b7f054217fa8f0ed45804c0b5)) +* refactor mfa and aal update methods ([#1503](https://github.com/supabase/auth/issues/1503)) ([31a5854](https://github.com/supabase/auth/commit/31a585429bf248aa919d94c82c7c9e0c1c695461)) +* rename from CustomSMSProvider to SendSMS ([#1513](https://github.com/supabase/auth/issues/1513)) ([c0bc37b](https://github.com/supabase/auth/commit/c0bc37b44effaebb62ba85102f072db07fe57e48)) + +## [2.145.0](https://github.com/supabase/gotrue/compare/v2.144.0...v2.145.0) (2024-03-26) + + +### Features + +* add error codes ([#1377](https://github.com/supabase/gotrue/issues/1377)) ([e4beea1](https://github.com/supabase/gotrue/commit/e4beea1cdb80544b0581f1882696a698fdf64938)) +* add kakao OIDC ([#1381](https://github.com/supabase/gotrue/issues/1381)) ([b5566e7](https://github.com/supabase/gotrue/commit/b5566e7ac001cc9f2bac128de0fcb908caf3a5ed)) +* clean up expired factors ([#1371](https://github.com/supabase/gotrue/issues/1371)) ([5c94207](https://github.com/supabase/gotrue/commit/5c9420743a9aef0675f823c30aa4525b4933836e)) +* configurable NameID format for SAML provider ([#1481](https://github.com/supabase/gotrue/issues/1481)) ([ef405d8](https://github.com/supabase/gotrue/commit/ef405d89e69e008640f275bc37f8ec02ad32da40)) +* HTTP Hook - Add custom envconfig decoding for HTTP Hook Secrets ([#1467](https://github.com/supabase/gotrue/issues/1467)) ([5b24c4e](https://github.com/supabase/gotrue/commit/5b24c4eb05b2b52c4177d5f41cba30cb68495c8c)) +* refactor PKCE FlowState to reduce duplicate code ([#1446](https://github.com/supabase/gotrue/issues/1446)) ([b8d0337](https://github.com/supabase/gotrue/commit/b8d0337922c6712380f6dc74f7eac9fb71b1ae48)) + + +### Bug Fixes + +* add http support for https hooks on localhost ([#1484](https://github.com/supabase/gotrue/issues/1484)) ([5c04104](https://github.com/supabase/gotrue/commit/5c04104bf77a9c2db46d009764ec3ec3e484fc09)) +* cleanup panics due to bad inactivity timeout code ([#1471](https://github.com/supabase/gotrue/issues/1471)) ([548edf8](https://github.com/supabase/gotrue/commit/548edf898161c9ba9a136fc99ec2d52a8ba1f856)) +* **docs:** remove bracket on file name for broken link ([#1493](https://github.com/supabase/gotrue/issues/1493)) ([96f7a68](https://github.com/supabase/gotrue/commit/96f7a68a5479825e31106c2f55f82d5b2c007c0f)) +* impose expiry on auth code instead of magic link ([#1440](https://github.com/supabase/gotrue/issues/1440)) ([35aeaf1](https://github.com/supabase/gotrue/commit/35aeaf1b60dd27a22662a6d1955d60cc907b55dd)) +* invalidate email, phone OTPs on password change ([#1489](https://github.com/supabase/gotrue/issues/1489)) ([960a4f9](https://github.com/supabase/gotrue/commit/960a4f94f5500e33a0ec2f6afe0380bbc9562500)) +* move creation of flow state into function ([#1470](https://github.com/supabase/gotrue/issues/1470)) ([4392a08](https://github.com/supabase/gotrue/commit/4392a08d68d18828005d11382730117a7b143635)) +* prevent user email side-channel leak on verify ([#1472](https://github.com/supabase/gotrue/issues/1472)) ([311cde8](https://github.com/supabase/gotrue/commit/311cde8d1e82f823ae26a341e068034d60273864)) +* refactor email sending functions ([#1495](https://github.com/supabase/gotrue/issues/1495)) ([285c290](https://github.com/supabase/gotrue/commit/285c290adf231fea7ca1dff954491dc427cf18e2)) +* refactor factor_test to centralize setup ([#1473](https://github.com/supabase/gotrue/issues/1473)) ([c86007e](https://github.com/supabase/gotrue/commit/c86007e59684334b5e8c2285c36094b6eec89442)) +* refactor mfa challenge and tests ([#1469](https://github.com/supabase/gotrue/issues/1469)) ([6c76f21](https://github.com/supabase/gotrue/commit/6c76f21cee5dbef0562c37df6a546939affb2f8d)) +* Resend SMS when duplicate SMS sign ups are made ([#1490](https://github.com/supabase/gotrue/issues/1490)) ([73240a0](https://github.com/supabase/gotrue/commit/73240a0b096977703e3c7d24a224b5641ce47c81)) +* unlink identity bugs ([#1475](https://github.com/supabase/gotrue/issues/1475)) ([73e8d87](https://github.com/supabase/gotrue/commit/73e8d8742de3575b3165a707b5d2f486b2598d9d)) + +## [2.144.0](https://github.com/supabase/gotrue/compare/v2.143.0...v2.144.0) (2024-03-04) + + +### Features + +* add configuration for custom sms sender hook ([#1428](https://github.com/supabase/gotrue/issues/1428)) ([1ea56b6](https://github.com/supabase/gotrue/commit/1ea56b62d47edb0766d9e445406ecb43d387d920)) +* anonymous sign-ins ([#1460](https://github.com/supabase/gotrue/issues/1460)) ([130df16](https://github.com/supabase/gotrue/commit/130df165270c69c8e28aaa1b9421342f997c1ff3)) +* clean up test setup in MFA tests ([#1452](https://github.com/supabase/gotrue/issues/1452)) ([7185af8](https://github.com/supabase/gotrue/commit/7185af8de4a269cdde2629054d222333d3522ebe)) +* pass transaction to `invokeHook`, fixing pool exhaustion ([#1465](https://github.com/supabase/gotrue/issues/1465)) ([b536d36](https://github.com/supabase/gotrue/commit/b536d368f35adb31f937169e3f093d28352fa7be)) +* refactor resource owner password grant ([#1443](https://github.com/supabase/gotrue/issues/1443)) ([e63ad6f](https://github.com/supabase/gotrue/commit/e63ad6ff0f67d9a83456918a972ecb5109125628)) +* use dummy instance id to improve performance on refresh token queries ([#1454](https://github.com/supabase/gotrue/issues/1454)) ([656474e](https://github.com/supabase/gotrue/commit/656474e1b9ff3d5129190943e8c48e456625afe5)) + + +### Bug Fixes + +* expose `provider` under `amr` in access token ([#1456](https://github.com/supabase/gotrue/issues/1456)) ([e9f38e7](https://github.com/supabase/gotrue/commit/e9f38e76d8a7b93c5c2bb0de918a9b156155f018)) +* improve MFA QR Code resilience so as to support providers like 1Password ([#1455](https://github.com/supabase/gotrue/issues/1455)) ([6522780](https://github.com/supabase/gotrue/commit/652278046c9dd92f5cecd778735b058ef3fb41c7)) +* refactor request params to use generics ([#1464](https://github.com/supabase/gotrue/issues/1464)) ([e1cdf5c](https://github.com/supabase/gotrue/commit/e1cdf5c4b5c1bf467094f4bdcaa2e42a5cc51c20)) +* revert refactor resource owner password grant ([#1466](https://github.com/supabase/gotrue/issues/1466)) ([fa21244](https://github.com/supabase/gotrue/commit/fa21244fa929709470c2e1fc4092a9ce947399e7)) +* update file name so migration to Drop IP Address is applied ([#1447](https://github.com/supabase/gotrue/issues/1447)) ([f29e89d](https://github.com/supabase/gotrue/commit/f29e89d7d2c48ee8fd5bf8279a7fa3db0ad4d842)) + +## [2.143.0](https://github.com/supabase/gotrue/compare/v2.142.0...v2.143.0) (2024-02-19) + + +### Features + +* calculate aal without transaction ([#1437](https://github.com/supabase/gotrue/issues/1437)) ([8dae661](https://github.com/supabase/gotrue/commit/8dae6614f1a2b58819f94894cef01e9f99117769)) + + +### Bug Fixes + +* deprecate hooks ([#1421](https://github.com/supabase/gotrue/issues/1421)) ([effef1b](https://github.com/supabase/gotrue/commit/effef1b6ecc448b7927eff23df8d5b509cf16b5c)) +* error should be an IsNotFoundError ([#1432](https://github.com/supabase/gotrue/issues/1432)) ([7f40047](https://github.com/supabase/gotrue/commit/7f40047aec3577d876602444b1d88078b2237d66)) +* populate password verification attempt hook ([#1436](https://github.com/supabase/gotrue/issues/1436)) ([f974bdb](https://github.com/supabase/gotrue/commit/f974bdb58340395955ca27bdd26d57062433ece9)) +* restrict mfa enrollment to aal2 if verified factors are present ([#1439](https://github.com/supabase/gotrue/issues/1439)) ([7e10d45](https://github.com/supabase/gotrue/commit/7e10d45e54010d38677f4c3f2f224127688eb9a2)) +* update phone if autoconfirm is enabled ([#1431](https://github.com/supabase/gotrue/issues/1431)) ([95db770](https://github.com/supabase/gotrue/commit/95db770c5d2ecca4a1e960a8cb28ded37cccc100)) +* use email change email in identity ([#1429](https://github.com/supabase/gotrue/issues/1429)) ([4d3b9b8](https://github.com/supabase/gotrue/commit/4d3b9b8841b1a5fa8f3244825153cc81a73ba300)) + +## [2.142.0](https://github.com/supabase/gotrue/compare/v2.141.0...v2.142.0) (2024-02-14) + + +### Features + +* alter tag to use raw ([#1427](https://github.com/supabase/gotrue/issues/1427)) ([53cfe5d](https://github.com/supabase/gotrue/commit/53cfe5de57d4b5ab6e8e2915493856ecd96f4ede)) +* update README.md to trigger release ([#1425](https://github.com/supabase/gotrue/issues/1425)) ([91e0e24](https://github.com/supabase/gotrue/commit/91e0e245f5957ebce13370f79fd4a6be8108ed80)) + +## [2.141.0](https://github.com/supabase/gotrue/compare/v2.140.0...v2.141.0) (2024-02-13) + + +### Features + +* drop sha hash tag ([#1422](https://github.com/supabase/gotrue/issues/1422)) ([76853ce](https://github.com/supabase/gotrue/commit/76853ce6d45064de5608acc8100c67a8337ba791)) +* prefix release with v ([#1424](https://github.com/supabase/gotrue/issues/1424)) ([9d398cd](https://github.com/supabase/gotrue/commit/9d398cd75fca01fb848aa88b4f545552e8b5751a)) + +## [2.140.0](https://github.com/supabase/gotrue/compare/v2.139.2...v2.140.0) (2024-02-13) + + +### Features + +* deprecate existing webhook implementation ([#1417](https://github.com/supabase/gotrue/issues/1417)) ([5301e48](https://github.com/supabase/gotrue/commit/5301e481b0c7278c18b4578a5b1aa8d2256c2f5d)) +* update publish.yml checkout repository so there is access to Dockerfile ([#1419](https://github.com/supabase/gotrue/issues/1419)) ([7cce351](https://github.com/supabase/gotrue/commit/7cce3518e8c9f1f3f93e4f6a0658ee08771c4f1c)) + +## [2.139.2](https://github.com/supabase/gotrue/compare/v2.139.1...v2.139.2) (2024-02-08) + + +### Bug Fixes + +* improve perf in account linking ([#1394](https://github.com/supabase/gotrue/issues/1394)) ([8eedb95](https://github.com/supabase/gotrue/commit/8eedb95dbaa310aac464645ec91d6a374813ab89)) +* OIDC provider validation log message ([#1380](https://github.com/supabase/gotrue/issues/1380)) ([27e6b1f](https://github.com/supabase/gotrue/commit/27e6b1f9a4394c5c4f8dff9a8b5529db1fc67af9)) +* only create or update the email / phone identity after it's been verified ([#1403](https://github.com/supabase/gotrue/issues/1403)) ([2d20729](https://github.com/supabase/gotrue/commit/2d207296ec22dd6c003c89626d255e35441fd52d)) +* only create or update the email / phone identity after it's been verified (again) ([#1409](https://github.com/supabase/gotrue/issues/1409)) ([bc6a5b8](https://github.com/supabase/gotrue/commit/bc6a5b884b43fe6b8cb924d3f79999fe5bfe7c5f)) +* unmarshal is_private_email correctly ([#1402](https://github.com/supabase/gotrue/issues/1402)) ([47df151](https://github.com/supabase/gotrue/commit/47df15113ce8d86666c0aba3854954c24fe39f7f)) +* use `pattern` for semver docker image tags ([#1411](https://github.com/supabase/gotrue/issues/1411)) ([14a3aeb](https://github.com/supabase/gotrue/commit/14a3aeb6c3f46c8d38d98cc840112dfd0278eeda)) + + +### Reverts + +* "fix: only create or update the email / phone identity after i… ([#1407](https://github.com/supabase/gotrue/issues/1407)) ([ff86849](https://github.com/supabase/gotrue/commit/ff868493169a0d9ac18b66058a735197b1df5b9b)) diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..cb9b3cac9 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @supabase/auth diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3003d492c..22a1add3f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,40 +1,78 @@ # CONTRIBUTING We would love to have contributions from each and every one of you in the community be it big or small and you are the ones who motivate us to do better than what we do today. -Docs aren't perfect and so we're here to help. If you're stuck on setup for more than thirty minutes please feel free to reach out on the `#gotrue` channel on [the Supabase discord](https://discord.gg/tNy8sVEf) ## Code Of Conduct -Please help us keep all our projects open and inclusive. Kindly follow our [Code of Conduct](<(CODE_OF_CONDUCT.md)>) to keep the ecosystem healthy and friendly for all. +Please help us keep all our projects open and inclusive. Kindly follow our [Code of Conduct](CODE_OF_CONDUCT.md) to keep the ecosystem healthy and friendly for all. + +## Quick Start + +Auth has a development container setup that makes it easy to get started contributing. This setup only requires that [Docker](https://www.docker.com/get-started) is setup on your system. The development container setup includes a PostgreSQL container with migrations already applied and a container running GoTrue that will perform a hot reload when changes to the source code are detected. + +If you would like to run Auth locally or learn more about what these containers are doing for you, continue reading the [Setup and Tooling](#setup-and-tooling) section below. Otherwise, you can skip ahead to the [How To Verify that GoTrue is Available](#how-to-verify-that-auth-is-available) section to learn about working with and developing GoTrue. + +Before using the containers, you will need to make sure an `.env.docker` file exists by making a copy of `example.docker.env` and configuring it for your needs. The set of env vars in `example.docker.env` only contain the necessary env vars for auth to start in a docker environment. For the full list of env vars, please refer to `example.env` and copy over the necessary ones into your `.env.docker` file. + +The following are some basic commands. A full and up to date list of commands can be found in the project's `Makefile` or by running `make help`. + +### Starting the containers + +Start the containers as described above in an attached state with log output. + +```bash +make dev +``` + +### Running tests in the containers + +Start the containers with a fresh database and run the project's tests. + +```bash +make docker-test +``` + +### Removing the containers + +Remove both containers and their volumes. This removes any data associated with the containers. + +```bash +make docker-clean +``` + +### Rebuild the containers + +Fully rebuild the containers without using any cached layers. + +```bash +make docker-build +``` ## Setup and Tooling -GoTrue -- as the name implies -- is a user registration and authentication API developed in [Go](https://go.dev). +Auth -- as the name implies -- is a user registration and authentication API developed in [Go](https://go.dev). It connects to a [PostgreSQL](https://www.postgresql.org) database in order to store authentication data, [Soda CLI](https://gobuffalo.io/en/docs/db/toolbox) to manage database schema and migrations, and runs inside a [Docker](https://www.docker.com/get-started) container. -Therefore, to contribute to GoTrue you will need to install these tools. +Therefore, to contribute to Auth you will need to install these tools. ### Install Tools -- Install [Go](https://go.dev) 1.16 +- Install [Go](https://go.dev) 1.22 -```terminal -# Via Homebrew on OSX -brew install go@1.16 - -# Set the GOPATH environment variable in the ~/.zshrc file -export GOPATH="$HOME/go" +```zsh +# Via Homebrew on macOS +brew install go@1.22 -# Add the GOPATH to your path -echo 'export PATH="$GOPATH/bin:$PATH"' >> ~/.zshrc +# Set the environment variable in the ~/.zshrc file +echo 'export PATH="/opt/homebrew/opt/go@1.22/bin:$PATH"' >> ~/.zshrc ``` - Install [Docker](https://www.docker.com/get-started) -```terminal -# Via Homebrew on OSX +```zsh +# Via Homebrew on macOS brew install docker ``` @@ -42,103 +80,85 @@ Or, if you prefer, download [Docker Desktop](https://www.docker.com/get-started) - Install [Soda CLI](https://gobuffalo.io/en/docs/db/toolbox) +```zsh +# Via Homebrew on macOS +brew install gobuffalo/tap/pop +``` + If you are on macOS Catalina you may [run into issues installing Soda with Brew](https://github.com/gobuffalo/homebrew-tap/issues/5). Do check your `GOPATH` and run `go build -o /bin/soda github.com/gobuffalo/pop/soda` to resolve. -``` -go install github.com/gobuffalo/pop/soda@latest -``` - -- Clone the GoTrue [repository](https://github.com/supabase/gotrue) +- Clone the Auth [repository](https://github.com/supabase/auth) -``` -git clone https://github.com/supabase/gotrue +```zsh +git clone https://github.com/supabase/auth ``` -### Install GoTrue +### Install Auth To begin installation, be sure to start from the root directory. -- `cd gotrue` +- `cd auth` To complete installation, you will: - Install the PostgreSQL Docker image - Create the DB Schema and Migrations - Setup a local `.env` for environment variables -- Compile GoTrue -- Run the GoTrue binary executable +- Compile Auth +- Run the Auth binary executable #### Installation Steps 1. Start Docker 2. To install the PostgreSQL Docker image, run: -``` -./hack/postgresd.sh -``` - -You may see a message like: - -``` -Unable to find image 'postgres:13' locally -``` +```zsh +# Builds the postgres image +docker-compose -f docker-compose-dev.yml build postgres -And then - -``` -Pulling from library/postgres +# Runs the postgres container +docker-compose -f docker-compose-dev.yml up postgres ``` -as Docker installs the image: - -``` -Unable to find image 'postgres:13' locally -13: Pulling from library/postgres -968621624b32: Pull complete -9ef9c0761899: Pull complete -effb6e89256d: Pull complete -e19a7fe239e0: Pull complete -7f97626b93ac: Pull complete -ecc35a9a2c7c: Pull complete -b749e660435b: Pull complete -457ea4f6253a: Pull complete -722af21d2ec3: Pull complete -899eee526623: Pull complete -746f304547aa: Pull complete -2d4dfc6819e6: Pull complete -c99864ddd548: Pull complete -Digest: sha256:3c6d1cef78fe0c84a79c76f0907aed29895dff661fecd45103f7afe2a055078e -Status: Downloaded newer image for postgres:13 -f709b97d83fddc3b099e4f2ddc4cb2fbf68052e7a8093332bec57672f38cfa36 -``` - -You should then see in Docker that `gotrue_postgresql` is running on `port: 5432`. +You should then see in Docker that `auth_postgresql` is running on `port: 5432`. > **Important** If you happen to already have a local running instance of Postgres running on the port `5432` because you -> may have installed via [homebrew on OSX](https://formulae.brew.sh/formula/postgresql) then be certain to stop the process using: +> may have installed via [homebrew on macOS](https://formulae.brew.sh/formula/postgresql) then be certain to stop the process using: > > - `brew services stop postgresql` > > If you need to run the test environment on another port, you will need to modify several configuration files to use a different custom port. -3. Next compile the GoTrue binary: +3. Next compile the Auth binary: + +When you fork a repository, GitHub does not automatically copy all the tags (tags are not included by default). To ensure the correct tag is set before building the binary, you need to fetch the tags from the upstream repository and push them to your fork. Follow these steps: +```zsh +# Fetch the tags from the upstream repository +git fetch upstream --tags + +# Push the tags to your fork +git push origin --tags ``` + +Then build the binary by running: + +```zsh make build ``` 4. To setup the database schema via Soda, run: -``` +```zsh make migrate_test ``` -You should see log messages that indicate that the GoTrue migrations were applied successfully: +You should see log messages that indicate that the Auth migrations were applied successfully: ```terminal -INFO[0000] GoTrue migrations applied successfully +INFO[0000] Auth migrations applied successfully DEBU[0000] after status [POP] 2021/12/15 10:44:36 sql - SELECT EXISTS (SELECT schema_migrations.* FROM schema_migrations AS schema_migrations WHERE version = $1) | ["20210710035447"] [POP] 2021/12/15 10:44:36 sql - SELECT EXISTS (SELECT schema_migrations.* FROM schema_migrations AS schema_migrations WHERE version = $1) | ["20210722035447"] @@ -155,40 +175,41 @@ Version Name Status That lists each migration that was applied. Note: there may be more migrations than those listed. -4. Create a `.env` file in the root of the project and copy the following config in [example.env](example.env) -5. In order to have GoTrue connect to your PostgreSQL database running in Docker, it is important to set a connection string like: +4. Create a `.env` file in the root of the project and copy the following config in [example.env](example.env). Set the values to GOTRUE_SMS_TEST_OTP_VALID_UNTIL in the `.env` file. + +5. In order to have Auth connect to your PostgreSQL database running in Docker, it is important to set a connection string like: ``` DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/postgres" ``` -> Important: GoTrue requires a set of SMTP credentials to run, you can generate your own SMTP credentials via an SMTP provider such as AWS SES, SendGrid, MailChimp, SendInBlue or any other SMTP providers. +> Important: Auth requires a set of SMTP credentials to run, you can generate your own SMTP credentials via an SMTP provider such as AWS SES, SendGrid, MailChimp, SendInBlue or any other SMTP providers. -6. Then finally Start GoTrue -7. Verify that GoTrue is Available +6. Then finally Start Auth +7. Verify that Auth is Available -### Starting GoTrue +### Starting Auth -Start GoTrue by running the executable: +Start Auth by running the executable: -``` -./gotrue +```zsh +./auth ``` -This command will re-run migrations and then indicate that GoTrue has started: +This command will re-run migrations and then indicate that Auth has started: -``` -INFO[0000] GoTrue API started on: localhost:9999 +```zsh +INFO[0000] Auth API started on: localhost:9999 ``` -### How To Verify that GoTrue is Available +### How To Verify that Auth is Available -To test that your GoTrue is up and available, you can query the `health` endpoint at `http://localhost:9999/health`. You should see a response similar to: +To test that your Auth is up and available, you can query the `health` endpoint at `http://localhost:9999/health`. You should see a response similar to: ```json { - "description": "GoTrue is a user registration and authentication API", - "name": "GoTrue", + "description": "Auth is a user registration and authentication API", + "name": "Auth", "version": "" } ``` @@ -208,6 +229,7 @@ To see the current settings, make a request to `http://localhost:9999/settings` "facebook": false, "spotify": false, "slack": false, + "slack_oidc": false, "twitch": true, "twitter": false, "email": true, @@ -379,19 +401,19 @@ The response from `/admin/users` should return all users: If you need to run any new migrations: -``` +```zsh make migrate_test ``` ## Testing -Currently, we don't use a separate test database, so the same database created when installing GoTrue to run locally is used. +Currently, we don't use a separate test database, so the same database created when installing Auth to run locally is used. The following commands should help in setting up a database and running the tests: ```sh # Runs the database in a docker container -$ ./hack/postgresd.sh +$ docker-compose -f docker-compose-dev.yml up postgres # Applies the migrations to the database (requires soda cli) $ make migrate_test @@ -410,9 +432,9 @@ In these examples, we change the port from 5432 to 7432. > Note: This is not recommended, but if you do, please do not check in changes. ``` -// file: postgresd.sh -docker run --name gotrue_postgresql --p 7432:5432 \ 👈 set the first value to your external facing port +// file: docker-compose-dev.yml +ports: + - 7432:5432 \ 👈 set the first value to your external facing port ``` The port you customize here can them be used in the subsequent configuration: @@ -439,11 +461,16 @@ export GOTRUE_DB_DATABASE_URL="postgres://supabase_auth_admin:root@localhost:743 ## Helpful Docker Commands ``` +// file: docker-compose-dev.yml +container_name: auth_postgres +``` + +```zsh # Command line into bash on the PostgreSQL container -docker exec -it gotrue_postgresql bash +docker exec -it auth_postgres bash # Removes Container -docker container rm -f gotrue_postgresql +docker container rm -f auth_postgres # Removes volume docker volume rm postgres_data @@ -468,12 +495,14 @@ We actively welcome your pull requests. - Is there a corresponding issue created for it? If so, please include it in the PR description so we can track / refer to it. - Does your PR follow the [semantic-release commit guidelines](https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#-git-commit-guidelines)? -- If the PR is a `feat`, an [RFC](https://github.com/supabase/rfcs) or a detailed description of the design implementation is required. The former (RFC) is prefered before starting on the PR. +- If the PR is a `feat`, an [RFC](https://github.com/supabase/rfcs) or a detailed description of the design implementation is required. The former (RFC) is preferred before starting on the PR. - Are the existing tests passing? - Have you written some tests for your PR? ## Guidelines for Implementing Additional OAuth Providers +> ⚠️ We won't be accepting any additional oauth / sms provider contributions for now because we intend to support these through webhooks or a generic provider in the future. + Please ensure that an end-to-end test is done for the OAuth provider implemented. An end-to-end test includes: @@ -490,5 +519,5 @@ Since implementing an additional OAuth provider consists of making api calls to ## License -By contributing to GoTrue, you agree that your contributions will be licensed +By contributing to Auth, you agree that your contributions will be licensed under its [MIT license](LICENSE). diff --git a/Dockerfile b/Dockerfile index 61ef17f13..6fd9fc300 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,28 +1,32 @@ -FROM golang:1.16-alpine as build +FROM golang:1.23.7-alpine3.20 as build ENV GO111MODULE=on ENV CGO_ENABLED=0 ENV GOOS=linux RUN apk add --no-cache make git -WORKDIR /go/src/github.com/netlify/gotrue +WORKDIR /go/src/github.com/supabase/auth # Pulling dependencies COPY ./Makefile ./go.* ./ RUN make deps # Building stuff -COPY . /go/src/github.com/netlify/gotrue -RUN make build +COPY . /go/src/github.com/supabase/auth -FROM alpine:3.7 -RUN adduser -D -u 1000 netlify +# Make sure you change the RELEASE_VERSION value before publishing an image. +RUN RELEASE_VERSION=unspecified make build + +# Always use alpine:3 so the latest version is used. This will keep CA certs more up to date. +FROM alpine:3 +RUN adduser -D -u 1000 supabase RUN apk add --no-cache ca-certificates -COPY --from=build /go/src/github.com/netlify/gotrue/gotrue /usr/local/bin/gotrue -COPY --from=build /go/src/github.com/netlify/gotrue/migrations /usr/local/etc/gotrue/migrations/ +COPY --from=build /go/src/github.com/supabase/auth/auth /usr/local/bin/auth +COPY --from=build /go/src/github.com/supabase/auth/migrations /usr/local/etc/auth/migrations/ +RUN ln -s /usr/local/bin/auth /usr/local/bin/gotrue -ENV GOTRUE_DB_MIGRATIONS_PATH /usr/local/etc/gotrue/migrations +ENV GOTRUE_DB_MIGRATIONS_PATH /usr/local/etc/auth/migrations -USER netlify -CMD ["gotrue"] +USER supabase +CMD ["auth"] diff --git a/Dockerfile.dev b/Dockerfile.dev new file mode 100644 index 000000000..9f0a30421 --- /dev/null +++ b/Dockerfile.dev @@ -0,0 +1,18 @@ +FROM golang:1.23.7-alpine3.20 +ENV GO111MODULE=on +ENV CGO_ENABLED=0 +ENV GOOS=linux + +RUN apk add --no-cache make git bash + +WORKDIR /go/src/github.com/supabase/auth + +# Pulling dependencies +COPY ./Makefile ./go.* ./ + +# Production dependencies +RUN make deps + +# Development dependences +RUN go get github.com/githubnemo/CompileDaemon +RUN go install github.com/githubnemo/CompileDaemon diff --git a/Dockerfile.postgres.dev b/Dockerfile.postgres.dev new file mode 100644 index 000000000..58661ef6b --- /dev/null +++ b/Dockerfile.postgres.dev @@ -0,0 +1,8 @@ +FROM postgres:15 +WORKDIR / +RUN pwd +COPY init_postgres.sh /docker-entrypoint-initdb.d/init.sh +RUN chmod +x /docker-entrypoint-initdb.d/init.sh +EXPOSE 5432 + +CMD ["postgres"] diff --git a/Makefile b/Makefile index c59f4fb4b..7bfd8ab98 100644 --- a/Makefile +++ b/Makefile @@ -1,26 +1,38 @@ -.PHONY: all build deps image lint migrate test vet -CHECK_FILES?=$$(go list ./... | grep -v /vendor/) -FLAGS?=-ldflags "-X github.com/netlify/gotrue/cmd.Version=`git describe --tags`" +.PHONY: all build deps dev-deps image migrate test vet sec format unused +CHECK_FILES?=./... + +FLAGS=-ldflags "-X github.com/supabase/auth/internal/utilities.Version=`git describe --tags`" -buildvcs=false +ifdef RELEASE_VERSION + FLAGS=-ldflags "-X github.com/supabase/auth/internal/utilities.Version=v$(RELEASE_VERSION)" -buildvcs=false +endif + +ifneq ($(shell docker compose version 2>/dev/null),) + DOCKER_COMPOSE=docker compose +else + DOCKER_COMPOSE=docker-compose +endif + +DEV_DOCKER_COMPOSE:=docker-compose-dev.yml help: ## Show this help. @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {sub("\\\\n",sprintf("\n%22c"," "), $$2);printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) -all: lint vet build ## Run the tests and build the binary. +all: vet sec static build ## Run the tests and build the binary. + +build: deps ## Build the binary. + CGO_ENABLED=0 go build $(FLAGS) + CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build $(FLAGS) -o auth-arm64 -build: ## Build the binary. - go build $(FLAGS) - GOOS=linux GOARCH=arm64 go build $(FLAGS) -o gotrue-arm64 +dev-deps: ## Install developer dependencies + @go install github.com/gobuffalo/pop/soda@latest + @go install github.com/securego/gosec/v2/cmd/gosec@latest + @go install honnef.co/go/tools/cmd/staticcheck@latest + @go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@latest + @go install github.com/nishanths/exhaustive/cmd/exhaustive@latest deps: ## Install dependencies. - @go install github.com/gobuffalo/pop/soda@latest - @go install golang.org/x/lint/golint@latest @go mod download - -image: ## Build the Docker image. - docker build . - -lint: ## Lint the code. - golint $(CHECK_FILES) + @go mod verify migrate_dev: ## Run database migrations for development. hack/migrate.sh postgres @@ -28,8 +40,54 @@ migrate_dev: ## Run database migrations for development. migrate_test: ## Run database migrations for test. hack/migrate.sh postgres -test: ## Run tests. - go test -p 1 -v $(CHECK_FILES) +test: build ## Run tests. + go test $(CHECK_FILES) -coverprofile=coverage.out -coverpkg ./... -p 1 -race -v -count=1 + ./hack/coverage.sh vet: # Vet the code go vet $(CHECK_FILES) + +sec: dev-deps # Check for security vulnerabilities + gosec -quiet -exclude-generated $(CHECK_FILES) + gosec -quiet -tests -exclude-generated -exclude=G104 $(CHECK_FILES) + +unused: dev-deps # Look for unused code + @echo "Unused code:" + staticcheck -checks U1000 $(CHECK_FILES) + + @echo + + @echo "Code used only in _test.go (do move it in those files):" + staticcheck -checks U1000 -tests=false $(CHECK_FILES) + +static: dev-deps + staticcheck ./... + exhaustive ./... + +generate: dev-deps + go generate ./... + +dev: ## Run the development containers + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) up + +down: ## Shutdown the development containers + # Start postgres first and apply migrations + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) down + +docker-test: ## Run the tests using the development containers + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) up -d postgres + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) run auth sh -c "make migrate_test" + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) run auth sh -c "make test" + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) down -v + +docker-build: ## Force a full rebuild of the development containers + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) build --no-cache + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) up -d postgres + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) run auth sh -c "make migrate_dev" + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) down + +docker-clean: ## Remove the development containers and volumes + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) rm -fsv + +format: + gofmt -s -w . diff --git a/README.md b/README.md index 04ee5d808..08be82ba4 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,171 @@ -# GoTrue - User management for APIs +# Auth - Authentication and User Management by Supabase -GoTrue is a small open-source API written in golang, that can act as a self-standing -API service for handling user registration and authentication for JAM projects. +[![Coverage Status](https://coveralls.io/repos/github/supabase/auth/badge.svg?branch=master)](https://coveralls.io/github/supabase/auth?branch=master) -It's based on OAuth2 and JWT and will handle user signup, authentication and custom -user data. +Auth is a user management and authentication server written in Go that powers +[Supabase](https://supabase.com)'s features such as: + +- Issuing JWTs +- Row Level Security with PostgREST +- User management +- Sign in with email, password, magic link, phone number +- Sign in with external providers (Google, Apple, Facebook, Discord, ...) + +It is originally based on the excellent +[GoTrue codebase by Netlify](https://github.com/netlify/gotrue), however both have diverged significantly in features and capabilities. + +If you wish to contribute to the project, please refer to the [contributing guide](/CONTRIBUTING.md). + +## Table Of Contents + +- [Quick Start](#quick-start) +- [Running in Production](#running-in-production) +- [Configuration](#configuration) +- [Endpoints](#endpoints) ## Quick Start Create a `.env` file to store your own custom env vars. See [`example.env`](example.env) -1. Start the local postgres database in a postgres container: `./hack/postgresd.sh` -2. Build the gotrue binary: `make build` . You should see an output like this: -``` -go build -ldflags "-X github.com/supabase/gotrue/cmd.Version=`git rev-parse HEAD`" -GOOS=linux GOARCH=arm64 go build -ldflags "-X github.com/supabase/gotrue/cmd.Version=`git rev-parse HEAD`" -o gotrue-arm64 +1. Start the local postgres database in a postgres container: `docker-compose -f docker-compose-dev.yml up postgres` +2. Build the auth binary: `make build` . You should see an output like this: + +```bash +go build -ldflags "-X github.com/supabase/auth/cmd.Version=`git rev-parse HEAD`" +GOOS=linux GOARCH=arm64 go build -ldflags "-X github.com/supabase/auth/cmd.Version=`git rev-parse HEAD`" -o gotrue-arm64 ``` -3. Execute the gotrue binary: `./gotrue` (if you're on x86) `./gotrue-arm64` (if you're on arm) + +3. Execute the auth binary: `./auth` + +### If you have docker installed + +Create a `.env.docker` file to store your own custom env vars. See [`example.docker.env`](example.docker.env) + +1. `make build` +2. `make dev` +3. `docker ps` should show 2 docker containers (`auth_postgresql` and `gotrue_gotrue`) +4. That's it! Visit the [health checkendpoint](http://localhost:9999/health) to confirm that auth is running. + +## Running in production + +Running an authentication server in production is not an easy feat. We +recommend using [Supabase Auth](https://supabase.com/auth) which gets regular +security updates. + +Otherwise, please make sure you setup a process to promptly update to the +latest version. You can do that by following this repository, specifically the +[Releases](https://github.com/supabase/auth/releases) and [Security +Advisories](https://github.com/supabase/auth/security/advisories) sections. + +### Backward compatibility + +Auth uses the [Semantic Versioning](https://semver.org) scheme. Here are some +further clarifications on backward compatibility guarantees: + +**Go API compatibility** + +Auth is not meant to be used as a Go library. There are no guarantees on +backward API compatibility when used this way regardless which version number +changes. + +**Patch** + +Changes to the patch version guarantees backward compatibility with: + +- Database objects (tables, columns, indexes, functions). +- REST API +- JWT structure +- Configuration + +Guaranteed examples: + +- A column won't change its type. +- A table won't change its primary key. +- An index will not be removed. +- A uniqueness constraint will not be removed. +- A REST API will not be removed. +- Parameters to REST APIs will work equivalently as before (or better, if a bug + has been fixed). +- Configuration will not change. + +Not guaranteed examples: + +- A table may add new columns. +- Columns in a table may be reordered. +- Non-unique constraints may be removed (database level checks, null, default + values). +- JWT may add new properties. + +**Minor** + +Changes to minor version guarantees backward compatibility with: + +- REST API +- JWT structure +- Configuration + +Exceptions to these guarantees will be made only when serious security issues +are found that can't be remedied in any other way. + +Guaranteed examples: + +- Existing APIs may be deprecated but continue working for the next few minor + version releases. +- Configuration changes may become deprecated but continue working for the next + few minor version releases. +- Already issued JWTs will be accepted, but new JWTs may be with a different + structure (but usually similar). + +Not guaranteed examples: + +- Removal of JWT fields after a deprecation notice. +- Removal of certain APIs after a deprecation notice. +- Removal of sign-in with external providers, after a deprecation notice. +- Deletion, truncation, significant schema changes to tables, indexes, views, + functions. + +We aim to provide a deprecation notice in execution logs for at least two major +version releases or two weeks if multiple releases go out. Compatibility will +be guaranteed while the notice is live. + +**Major** + +Changes to the major version do not guarantee any backward compatibility with +previous versions. + +### Inherited features + +Certain inherited features from the Netlify codebase are not supported by +Supabase and they may be removed without prior notice in the future. This is a +comprehensive list of those features: + +1. Multi-tenancy via the `instances` table i.e. `GOTRUE_MULTI_INSTANCE_MODE` + configuration parameter. +2. System user (zero UUID user). +3. Super admin via the `is_super_admin` column. +4. Group information in JWTs via `GOTRUE_JWT_ADMIN_GROUP_NAME` and other + configuration fields. +5. Symmetrics JWTs. In the future it is very likely that Auth will begin + issuing asymmetric JWTs (subject to configuration), so do not rely on the + assumption that only HS256 signed JWTs will be issued long term. + +Note that this is not an exhaustive list and it may change. + +### Best practices when self-hosting + +These are some best practices to follow when self-hosting to ensure backward +compatibility with Auth: + +1. Do not modify the schema managed by Auth. You can see all of the + migrations in the `migrations` directory. +2. Do not rely on schema and structure of data in the database. Always use + Auth APIs and JWTs to infer information about users. +3. Always run Auth behind a TLS-capable proxy such as a load balancer, CDN, + nginx or other similar software. ## Configuration -You may configure GoTrue using either a configuration file named `.env`, +You may configure Auth using either a configuration file named `.env`, environment variables, or a combination of both. Environment variables are prefixed with `GOTRUE_`, and will always have precedence over values provided via file. ### Top-Level @@ -35,7 +180,9 @@ The base URL your site is located at. Currently used in combination with other s `URI_ALLOW_LIST` - `string` -A comma separated list of URIs (e.g. "https://supabase.io/welcome,io.supabase.gotruedemo://logincallback") which are permitted as valid `redirect_to` destinations, in addition to SITE_URL. Defaults to []. +A comma separated list of URIs (e.g. `"https://foo.example.com,https://*.foo.example.com,https://bar.example.com"`) which are permitted as valid `redirect_to` destinations. Defaults to []. Supports wildcard matching through globbing. e.g. `https://*.foo.example.com` will allow `https://a.foo.example.com` and `https://b.foo.example.com` to be accepted. Globbing is also supported on subdomains. e.g. `https://foo.example.com/*` will allow `https://foo.example.com/page1` and `https://foo.example.com/page2` to be accepted. + +For more common glob patterns, check out the [following link](https://pkg.go.dev/github.com/gobwas/glob#Compile). `OPERATOR_TOKEN` - `string` _Multi-instance mode only_ @@ -62,15 +209,28 @@ Header on which to rate limit the `/token` endpoint. Rate limit the number of emails sent per hr on the following endpoints: `/signup`, `/invite`, `/magiclink`, `/recover`, `/otp`, & `/user`. -`PASSWORD_MIN_LENGTH` - `int` +`GOTRUE_PASSWORD_MIN_LENGTH` - `int` Minimum password length, defaults to 6. +`GOTRUE_PASSWORD_REQUIRED_CHARACTERS` - a string of character sets separated by `:`. A password must contain at least one character of each set to be accepted. To use the `:` character escape it with `\`. + +`GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED` - `bool` + +If refresh token rotation is enabled, auth will automatically detect malicious attempts to reuse a revoked refresh token. When a malicious attempt is detected, gotrue immediately revokes all tokens that descended from the offending token. + +`GOTRUE_SECURITY_REFRESH_TOKEN_REUSE_INTERVAL` - `string` + +This setting is only applicable if `GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED` is enabled. The reuse interval for a refresh token allows for exchanging the refresh token multiple times during the interval to support concurrency or offline issues. During the reuse interval, auth will not consider using a revoked token as a malicious attempt and will simply return the child refresh token. + +Only the previous revoked token can be reused. Using an old refresh token way before the current valid refresh token will trigger the reuse detection. + ### API ```properties GOTRUE_API_HOST=localhost PORT=9999 +API_EXTERNAL_URL=http://localhost:9999 ``` `API_HOST` - `string` @@ -85,6 +245,10 @@ Port number to listen on. Defaults to `8081`. Controls what endpoint Netlify can access this API on. +`API_EXTERNAL_URL` - `string` **required** + +The URL on which Gotrue might be accessed at. + `REQUEST_ID_HEADER` - `string` If you wish to inherit a request ID from the incoming request, specify the name in this value. @@ -92,35 +256,38 @@ If you wish to inherit a request ID from the incoming request, specify the name ### Database ```properties -GOTRUE_DB_DRIVER=mysql -DATABASE_URL=root@localhost/gotrue +GOTRUE_DB_DRIVER=postgres +DATABASE_URL=root@localhost/auth ``` `DB_DRIVER` - `string` **required** -Chooses what dialect of database you want. Must be `mysql`. +Chooses what dialect of database you want. Must be `postgres`. `DATABASE_URL` (no prefix) / `DB_DATABASE_URL` - `string` **required** Connection string for the database. +`GOTRUE_DB_MAX_POOL_SIZE` - `int` + +Sets the maximum number of open connections to the database. Defaults to 0 which is equivalent to an "unlimited" number of connections. + `DB_NAMESPACE` - `string` Adds a prefix to all table names. **Migrations Note** -Migrations are not applied automatically, so you will need to run them after -you've built gotrue. +Migrations are applied automatically when you run `./auth`. However, you also have the option to rerun the migrations via the following methods: -- If built locally: `./gotrue migrate` -- Using Docker: `docker run --rm gotrue gotrue migrate` +- If built locally: `./auth migrate` +- Using Docker: `docker run --rm auth gotrue migrate` ### Logging ```properties LOG_LEVEL=debug # available without GOTRUE prefix (exception) -GOTRUE_LOG_FILE=/var/log/go/gotrue.log +GOTRUE_LOG_FILE=/var/log/go/auth.log ``` `LOG_LEVEL` - `string` @@ -131,37 +298,107 @@ Controls what log levels are output. Choose from `panic`, `fatal`, `error`, `war If you wish logs to be written to a file, set `log_file` to a valid file path. -### Opentracing +### Observability -Currently, only the Datadog tracer is supported. +Auth has basic observability built in. It is able to export +[OpenTelemetry](https://opentelemetry.io) metrics and traces to a collector. + +#### Tracing + +To enable tracing configure these variables: + +`GOTRUE_TRACING_ENABLED` - `boolean` + +`GOTRUE_TRACING_EXPORTER` - `string` only `opentelemetry` supported + +Make sure you also configure the [OpenTelemetry +Exporter](https://opentelemetry.io/docs/reference/specification/protocol/exporter/) +configuration for your collector or service. + +For example, if you use +[Honeycomb.io](https://docs.honeycomb.io/getting-data-in/opentelemetry/go-distro/#using-opentelemetry-without-the-honeycomb-distribution) +you should set these standard OpenTelemetry OTLP variables: + +``` +OTEL_SERVICE_NAME=auth +OTEL_EXPORTER_OTLP_PROTOCOL=grpc +OTEL_EXPORTER_OTLP_ENDPOINT=https://api.honeycomb.io:443 +OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=,x-honeycomb-dataset=auth" +``` + +#### Metrics + +To enable metrics configure these variables: + +`GOTRUE_METRICS_ENABLED` - `boolean` + +`GOTRUE_METRICS_EXPORTER` - `string` only `opentelemetry` and `prometheus` +supported + +Make sure you also configure the [OpenTelemetry +Exporter](https://opentelemetry.io/docs/reference/specification/protocol/exporter/) +configuration for your collector or service. + +If you use the `prometheus` exporter, the server host and port can be +configured using these standard OpenTelemetry variables: + +`OTEL_EXPORTER_PROMETHEUS_HOST` - IP address, default `0.0.0.0` + +`OTEL_EXPORTER_PROMETHEUS_PORT` - port number, default `9100` + +The metrics are exported on the `/` path on the server. + +If you use the `opentelemetry` exporter, the metrics are pushed to the +collector. + +For example, if you use +[Honeycomb.io](https://docs.honeycomb.io/getting-data-in/opentelemetry/go-distro/#using-opentelemetry-without-the-honeycomb-distribution) +you should set these standard OpenTelemetry OTLP variables: -```properties -GOTRUE_TRACING_ENABLED=true -GOTRUE_TRACING_HOST=127.0.0.1 -GOTRUE_TRACING_PORT=8126 -GOTRUE_TRACING_TAGS="tag1:value1,tag2:value2" -GOTRUE_SERVICE_NAME="gotrue" ``` +OTEL_SERVICE_NAME=auth +OTEL_EXPORTER_OTLP_PROTOCOL=grpc +OTEL_EXPORTER_OTLP_ENDPOINT=https://api.honeycomb.io:443 +OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=,x-honeycomb-dataset=auth" +``` + +Note that Honeycomb.io requires a paid plan to ingest metrics. + +If you need to debug an issue with traces or metrics not being pushed, you can +set `DEBUG=true` to get more insights from the OpenTelemetry SDK. -`TRACING_ENABLED` - `bool` +#### Custom resource attributes -Whether tracing is enabled or not. Defaults to `false`. +When using the OpenTelemetry tracing or metrics exporter you can define custom +resource attributes using the [standard `OTEL_RESOURCE_ATTRIBUTES` environment +variable](https://opentelemetry.io/docs/reference/specification/resource/sdk/#specifying-resource-information-via-an-environment-variable). -`TRACING_HOST` - `bool` +A default attribute `auth.version` is provided containing the build version. -The tracing destination. +#### Tracing HTTP routes -`TRACING_PORT` - `bool` +All HTTP calls to the Auth API are traced. Routes use the parametrized +version of the route, and the values for the route parameters can be found as +the `http.route.params.` span attribute. -The port for the tracing host. +For example, the following request: + +``` +GET /admin/users/4acde936-82dc-4552-b851-831fb8ce0927/ +``` -`TRACING_TAGS` - `string` +will be traced as: -A comma separated list of key:value pairs. These key value pairs will be added as tags to all opentracing spans. +``` +http.method = GET +http.route = /admin/users/{user_id} +http.route.params.user_id = 4acde936-82dc-4552-b851-831fb8ce0927 +``` -`SERVICE_NAME` - `string` +#### Go runtime and HTTP metrics -The name to use for the service. +All of the Go runtime metrics are exposed. Some HTTP metrics are also collected +by default. ### JSON Web Tokens (JWT) @@ -193,7 +430,7 @@ The default group to assign all new users to. ### External Authentication Providers -We support `apple`, `azure`, `bitbucket`, `discord`, `facebook`, `github`, `gitlab`, `google`, `linkedin`, `notion`, `spotify`, `slack`, `twitch` and `twitter` for external authentication. +We support `apple`, `azure`, `bitbucket`, `discord`, `facebook`, `figma`, `github`, `gitlab`, `google`, `keycloak`, `linkedin`, `notion`, `spotify`, `slack`, `twitch`, `twitter` and `workos` for external authentication. Use the names as the keys underneath `external` to configure each separately. @@ -224,14 +461,15 @@ The URI a OAuth2 provider will redirect to with the `code` and `state` values. `EXTERNAL_X_URL` - `string` -The base URL used for constructing the URLs to request authorization and access tokens. Used by `gitlab` only. Defaults to `https://gitlab.com`. +The base URL used for constructing the URLs to request authorization and access tokens. Used by `gitlab` and `keycloak`. For `gitlab` it defaults to `https://gitlab.com`. For `keycloak` you need to set this to your instance, for example: `https://keycloak.example.com/realms/myrealm` #### Apple OAuth To try out external authentication with Apple locally, you will need to do the following: 1. Remap localhost to \ in your `/etc/hosts` config. -2. Configure gotrue to serve HTTPS traffic over localhost by replacing `ListenAndServe` in [api.go](api/api.go) with: +2. Configure auth to serve HTTPS traffic over localhost by replacing `ListenAndServe` in [api.go](internal/api/api.go) with: + ``` func (a *API) ListenAndServe(hostAndPort string) { log := logrus.WithField("component", "api") @@ -256,6 +494,7 @@ To try out external authentication with Apple locally, you will need to do the f } } ``` + 3. Generate the crt and key file. See [here](https://www.freecodecamp.org/news/how-to-get-https-working-on-your-local-development-environment-in-5-minutes-7af615770eec/) for more information. 4. Generate the `GOTRUE_EXTERNAL_APPLE_SECRET` by following this [post](https://medium.com/identity-beyond-borders/how-to-configure-sign-in-with-apple-77c61e336003)! @@ -305,21 +544,25 @@ Sets the name of the sender. Defaults to the `SMTP_ADMIN_EMAIL` if not used. If you do not require email confirmation, you may set this to `true`. Defaults to `false`. +`MAILER_OTP_EXP` - `number` + +Controls the duration an email link or otp is valid for. + `MAILER_URLPATHS_INVITE` - `string` -URL path to use in the user invite email. Defaults to `/`. +URL path to use in the user invite email. Defaults to `/verify`. `MAILER_URLPATHS_CONFIRMATION` - `string` -URL path to use in the signup confirmation email. Defaults to `/`. +URL path to use in the signup confirmation email. Defaults to `/verify`. `MAILER_URLPATHS_RECOVERY` - `string` -URL path to use in the password reset email. Defaults to `/`. +URL path to use in the password reset email. Defaults to `/verify`. `MAILER_URLPATHS_EMAIL_CHANGE` - `string` -URL path to use in the email change confirmation email. Defaults to `/`. +URL path to use in the email change confirmation email. Defaults to `/verify`. `MAILER_SUBJECTS_INVITE` - `string` @@ -343,7 +586,7 @@ Email subject to use for email change confirmation. Defaults to `Confirm Email C `MAILER_TEMPLATES_INVITE` - `string` -URL path to an email template to use when inviting a user. +URL path to an email template to use when inviting a user. (e.g. `https://www.example.com/path-to-email-template.html`) `SiteURL`, `Email`, and `ConfirmationURL` variables are available. Default Content (if template is unavailable): @@ -360,7 +603,7 @@ Default Content (if template is unavailable): `MAILER_TEMPLATES_CONFIRMATION` - `string` -URL path to an email template to use when confirming a signup. +URL path to an email template to use when confirming a signup. (e.g. `https://www.example.com/path-to-email-template.html`) `SiteURL`, `Email`, and `ConfirmationURL` variables are available. Default Content (if template is unavailable): @@ -374,7 +617,7 @@ Default Content (if template is unavailable): `MAILER_TEMPLATES_RECOVERY` - `string` -URL path to an email template to use when resetting a password. +URL path to an email template to use when resetting a password. (e.g. `https://www.example.com/path-to-email-template.html`) `SiteURL`, `Email`, and `ConfirmationURL` variables are available. Default Content (if template is unavailable): @@ -388,7 +631,7 @@ Default Content (if template is unavailable): `MAILER_TEMPLATES_MAGIC_LINK` - `string` -URL path to an email template to use when sending magic link. +URL path to an email template to use when sending magic link. (e.g. `https://www.example.com/path-to-email-template.html`) `SiteURL`, `Email`, and `ConfirmationURL` variables are available. Default Content (if template is unavailable): @@ -402,7 +645,7 @@ Default Content (if template is unavailable): `MAILER_TEMPLATES_EMAIL_CHANGE` - `string` -URL path to an email template to use when confirming the change of an email address. +URL path to an email template to use when confirming the change of an email address. (e.g. `https://www.example.com/path-to-email-template.html`) `SiteURL`, `Email`, `NewEmail`, and `ConfirmationURL` variables are available. Default Content (if template is unavailable): @@ -417,27 +660,6 @@ Default Content (if template is unavailable):

Change Email

``` -`WEBHOOK_URL` - `string` - -Url of the webhook receiver endpoint. This will be called when events like `validate`, `signup` or `login` occur. - -`WEBHOOK_SECRET` - `string` - -Shared secret to authorize webhook requests. This secret signs the [JSON Web Signature](https://tools.ietf.org/html/draft-ietf-jose-json-web-signature-41) of the request. You _should_ use this to verify the integrity of the request. Otherwise others can feed your webhook receiver with fake data. - -`WEBHOOK_RETRIES` - `number` - -How often GoTrue should try a failed hook. - -`WEBHOOK_TIMEOUT_SEC` - `number` - -Time between retries (in seconds). - -`WEBHOOK_EVENTS` - `list` - -Which events should trigger a webhook. You can provide a comma separated list. -For example to listen to all events, provide the values `validate,signup,login`. - ### Phone Auth `SMS_AUTOCONFIRM` - `bool` @@ -467,12 +689,13 @@ Then you can use your [twilio credentials](https://www.twilio.com/docs/usage/req - `SMS_TWILIO_MESSAGE_SERVICE_SID` - can be set to your twilio sender mobile number Or Messagebird credentials, which can be obtained in the [Dashboard](https://dashboard.messagebird.com/en/developers/access): + - `SMS_MESSAGEBIRD_ACCESS_KEY` - your Messagebird access key - `SMS_MESSAGEBIRD_ORIGINATOR` - SMS sender (your Messagebird phone number with + or company name) ### CAPTCHA -- If enabled, CAPTCHA will check the request body for the `hcaptcha_token` field and make a verification request to the CAPTCHA provider. +- If enabled, CAPTCHA will check the request body for the `captcha_token` field and make a verification request to the CAPTCHA provider. `SECURITY_CAPTCHA_ENABLED` - `string` @@ -480,19 +703,32 @@ Whether captcha middleware is enabled `SECURITY_CAPTCHA_PROVIDER` - `string` -for now the only option supported is: `hcaptcha` +for now the only options supported are: `hcaptcha` and `turnstile` + +- `SECURITY_CAPTCHA_SECRET` - `string` +- `SECURITY_CAPTCHA_TIMEOUT` - `string` + +Retrieve from hcaptcha or turnstile account -`SECURITY_CAPTCHA_SECRET` - `string` +### Reauthentication -Retrieve from hcaptcha account +`SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION` - `bool` + +Enforce reauthentication on password update. + +### Anonymous Sign-Ins + +`GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED` - `bool` + +Use this to enable/disable anonymous sign-ins. ## Endpoints -GoTrue exposes the following endpoints: +Auth exposes the following endpoints: ### **GET /settings** -Returns the publicly available settings for this gotrue instance. +Returns the publicly available settings for this auth instance. ```json { @@ -502,15 +738,18 @@ Returns the publicly available settings for this gotrue instance. "bitbucket": true, "discord": true, "facebook": true, + "figma": true, "github": true, "gitlab": true, "google": true, + "keycloak": true, "linkedin": true, "notion": true, "slack": true, "spotify": true, "twitch": true, - "twitter": true + "twitter": true, + "workos": true }, "disable_signup": false, "autoconfirm": false @@ -524,7 +763,7 @@ Creates (POST) or Updates (PUT) the user based on the `user_id` specified. The ` ```js headers: { - "Authorization": "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" // admin role required + "Authorization": "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" // requires a role claim that can be set in the GOTRUE_JWT_ADMIN_ROLES env var } body: @@ -543,7 +782,7 @@ body: ### **POST /admin/generate_link** -Returns the corresponding email action link based on the type specified. +Returns the corresponding email action link based on the type specified. Among other things, the response also contains the query params of the action link as separate JSON fields for convenience (along with the email OTP from which the corresponding token is generated). ```js headers: @@ -569,6 +808,10 @@ Returns ```js { "action_link": "http://localhost:9999/verify?token=TOKEN&type=TYPE&redirect_to=REDIRECT_URL", + "email_otp": "EMAIL_OTP", + "hashed_token": "TOKEN", + "verification_type": "TYPE", + "redirect_to": "REDIRECT_URL", ... } ``` @@ -577,7 +820,7 @@ Returns Register a new user with an email and password. -```js +```json { "email": "email@example.com", "password": "secret" @@ -586,7 +829,7 @@ Register a new user with an email and password. returns: -```json +```js { "id": "11111111-2222-3333-4444-5555555555555", "email": "email@example.com", @@ -611,7 +854,7 @@ Register a new user with a phone number and password. Returns: -```json +```js { "id": "11111111-2222-3333-4444-5555555555555", // if duplicate sign up, this ID will be faux "phone": "12345678", @@ -622,13 +865,40 @@ Returns: ``` if AUTOCONFIRM is enabled and the sign up is a duplicate, then the endpoint will return: -``` + +```json { "code":400, "msg":"User already registered" } ``` +### **POST /resend** + +Allows a user to resend an existing signup, sms, email_change or phone_change OTP. + +```json +{ + "email": "user@example.com", + "type": "signup" +} +``` + +```json +{ + "phone": "12345678", + "type": "sms" +} +``` + +returns: + +```json +{ + "message_id": "msgid123456" +} +``` + ### **POST /invite** Invites a new user with an email. @@ -636,7 +906,7 @@ This endpoint requires the `service_role` or `supabase_admin` JWT set as an Auth e.g. -```json +```js headers: { "Authorization" : "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" } @@ -726,7 +996,7 @@ query params: User will be logged in and redirected to: -```json +``` SITE_URL/#access_token=jwt-token-representing-the-user&token_type=bearer&expires_in=3600&refresh_token=a-refresh-token&type=invite ``` @@ -746,9 +1016,11 @@ If `"create_user": true`, user will not be automatically signed up if the user d "phone": "12345678" // follows the E.164 format "create_user": true } +``` OR +```js // exactly the same as /magiclink { "email": "email@example.com" @@ -758,7 +1030,7 @@ OR Returns: -``` +```json {} ``` @@ -815,7 +1087,7 @@ query params: body: -```json +```js // Email login { "email": "name@domain.com", @@ -884,6 +1156,7 @@ method can be used to set custom user data. Changing the email will result in a { "email": "new-email@example.com", "password": "new-password", + "phone": "+123456789", "data": { "key": "value", "number": 10, @@ -899,11 +1172,32 @@ Returns: "id": "11111111-2222-3333-4444-5555555555555", "email": "email@example.com", "email_change_sent_at": "2016-05-15T20:49:40.882805774-07:00", + "phone": "+123456789", + "phone_change_sent_at": "2016-05-15T20:49:40.882805774-07:00", "created_at": "2016-05-15T19:53:12.368652374-07:00", "updated_at": "2016-05-15T19:53:12.368652374-07:00" } ``` +If `GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION` is enabled, the user will need to reauthenticate first. + +```json +{ + "password": "new-password", + "nonce": "123456" +} +``` + +### **GET /reauthenticate** + +Sends a nonce to the user's email (preferred) or phone. This endpoint requires the user to be logged in / authenticated first. The user needs to have either an email or phone number for the nonce to be sent successfully. + +```js +headers: { + "Authorization" : "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" +} +``` + ### **POST /logout** Logout a user (Requires authentication). @@ -918,13 +1212,14 @@ Get access_token from external oauth provider query params: ``` -provider=apple | azure | bitbucket | discord | facebook | github | gitlab | google | linkedin | notion | slack | spotify | twitch | twitter +provider=apple | azure | bitbucket | discord | facebook | figma | github | gitlab | google | keycloak | linkedin | notion | slack | spotify | twitch | twitter | workos + scopes= ``` Redirects to provider and then to `/callback` -For apple specific setup see: https://github.com/supabase/gotrue#apple-oauth +For apple specific setup see: ### **GET /callback** diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..c60730339 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,60 @@ +# Security Policy + +Auth is a project maintained by [Supabase](https://supabase.com). Below is +our security policy. + +Contact: security@supabase.io +Canonical: https://supabase.com/.well-known/security.txt + +At Supabase, we consider the security of our systems a top priority. But no +matter how much effort we put into system security, there can still be +vulnerabilities present. + +If you discover a vulnerability, we would like to know about it so we can take +steps to address it as quickly as possible. We would like to ask you to help us +better protect our clients and our systems. + +Out of scope vulnerabilities: + +- Clickjacking on pages with no sensitive actions. +- Unauthenticated/logout/login CSRF. +- Attacks requiring MITM or physical access to a user's device. +- Any activity that could lead to the disruption of our service (DoS). +- Content spoofing and text injection issues without showing an attack + vector/without being able to modify HTML/CSS. +- Email spoofing +- Missing DNSSEC, CAA, CSP headers +- Lack of Secure or HTTP only flag on non-sensitive cookies +- Deadlinks + +Please do the following: + +- E-mail your findings to security@supabase.io. +- Do not run automated scanners on our infrastructure or dashboard. If you wish + to do this, contact us and we will set up a sandbox for you. +- Do not take advantage of the vulnerability or problem you have discovered, + for example by downloading more data than necessary to demonstrate the + vulnerability or deleting or modifying other people's data, +- Do not reveal the problem to others until it has been resolved, +- Do not use attacks on physical security, social engineering, distributed + denial of service, spam or applications of third parties, and +- Do provide sufficient information to reproduce the problem, so we will be + able to resolve it as quickly as possible. Usually, the IP address or the URL + of the affected system and a description of the vulnerability will be + sufficient, but complex vulnerabilities may require further explanation. + +What we promise: + +- We will respond to your report within 3 business days with our evaluation of + the report and an expected resolution date, +- If you have followed the instructions above, we will not take any legal + action against you in regard to the report, +- We will handle your report with strict confidentiality, and not pass on your + personal details to third parties without your permission, +- We will keep you informed of the progress towards resolving the problem, +- In the public information concerning the problem reported, we will give your + name as the discoverer of the problem (unless you desire otherwise), and + +We strive to resolve all problems as quickly as possible, and we would like to +play an active role in the ultimate publication on the problem after it is +resolved. diff --git a/api/admin.go b/api/admin.go deleted file mode 100644 index 8022cda8f..000000000 --- a/api/admin.go +++ /dev/null @@ -1,347 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "strings" - "time" - - "github.com/go-chi/chi" - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/sethvargo/go-password/password" -) - -type adminUserParams struct { - Aud string `json:"aud"` - Role string `json:"role"` - Email string `json:"email"` - Phone string `json:"phone"` - Password *string `json:"password"` - EmailConfirm bool `json:"email_confirm"` - PhoneConfirm bool `json:"phone_confirm"` - UserMetaData map[string]interface{} `json:"user_metadata"` - AppMetaData map[string]interface{} `json:"app_metadata"` - BanDuration string `json:"ban_duration"` -} - -func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, error) { - userID, err := uuid.FromString(chi.URLParam(r, "user_id")) - if err != nil { - return nil, badRequestError("user_id must be an UUID") - } - - logEntrySetField(r, "user_id", userID) - instanceID := getInstanceID(r.Context()) - - u, err := models.FindUserByInstanceIDAndID(a.db, instanceID, userID) - if err != nil { - if models.IsNotFoundError(err) { - return nil, notFoundError("User not found") - } - return nil, internalServerError("Database error loading user").WithInternalError(err) - } - - return withUser(r.Context(), u), nil -} - -func (a *API) getAdminParams(r *http.Request) (*adminUserParams, error) { - params := adminUserParams{} - err := json.NewDecoder(r.Body).Decode(¶ms) - if err != nil { - return nil, badRequestError("Could not decode admin user params: %v", err) - } - return ¶ms, nil -} - -// adminUsers responds with a list of all users in a given audience -func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - instanceID := getInstanceID(ctx) - aud := a.requestAud(ctx, r) - - pageParams, err := paginate(r) - if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) - } - - sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{models.SortField{Name: models.CreatedAt, Dir: models.Descending}}) - if err != nil { - return badRequestError("Bad Sort Parameters: %v", err) - } - - filter := r.URL.Query().Get("filter") - - users, err := models.FindUsersInAudience(a.db, instanceID, aud, pageParams, sortParams, filter) - if err != nil { - return internalServerError("Database error finding users").WithInternalError(err) - } - addPaginationHeaders(w, r, pageParams) - - return sendJSON(w, http.StatusOK, map[string]interface{}{ - "users": users, - "aud": aud, - }) -} - -// adminUserGet returns information about a single user -func (a *API) adminUserGet(w http.ResponseWriter, r *http.Request) error { - user := getUser(r.Context()) - - return sendJSON(w, http.StatusOK, user) -} - -// adminUserUpdate updates a single user object -func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - user := getUser(ctx) - adminUser := getAdminUser(ctx) - instanceID := getInstanceID(ctx) - params, err := a.getAdminParams(r) - config := getConfig(ctx) - if err != nil { - return err - } - - err = a.db.Transaction(func(tx *storage.Connection) error { - if params.Role != "" { - if terr := user.SetRole(tx, params.Role); terr != nil { - return terr - } - } - - if params.EmailConfirm { - if terr := user.Confirm(tx); terr != nil { - return terr - } - } - - if params.PhoneConfirm { - if terr := user.ConfirmPhone(tx); terr != nil { - return terr - } - } - - if params.Password != nil { - if len(*params.Password) < config.PasswordMinLength { - return invalidPasswordLengthError(config) - } - - if terr := user.UpdatePassword(tx, *params.Password); terr != nil { - return terr - } - } - - if params.Email != "" { - if terr := user.SetEmail(tx, params.Email); terr != nil { - return terr - } - } - - if params.Phone != "" { - if terr := user.SetPhone(tx, params.Phone); terr != nil { - return terr - } - } - - if params.AppMetaData != nil { - if terr := user.UpdateAppMetaData(tx, params.AppMetaData); terr != nil { - return terr - } - } - - if params.UserMetaData != nil { - if terr := user.UpdateUserMetaData(tx, params.UserMetaData); terr != nil { - return terr - } - } - - if params.BanDuration != "" { - if params.BanDuration == "none" { - user.BannedUntil = nil - } else { - duration, terr := time.ParseDuration(params.BanDuration) - if terr != nil { - return badRequestError("Invalid format for ban_duration: %v", terr) - } - t := time.Now().Add(duration) - user.BannedUntil = &t - } - if terr := user.UpdateBannedUntil(tx); terr != nil { - return terr - } - } - - if terr := models.NewAuditLogEntry(tx, instanceID, adminUser, models.UserModifiedAction, map[string]interface{}{ - "user_id": user.ID, - "user_email": user.Email, - "user_phone": user.Phone, - }); terr != nil { - return terr - } - return nil - }) - - if err != nil { - if errors.Is(err, invalidPasswordLengthError(config)) { - return err - } - if strings.Contains(err.Error(), "Invalid format for ban_duration") { - return err - } - return internalServerError("Error updating user").WithInternalError(err) - } - - return sendJSON(w, http.StatusOK, user) -} - -// adminUserCreate creates a new user based on the provided data -func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - - instanceID := getInstanceID(ctx) - adminUser := getAdminUser(ctx) - params, err := a.getAdminParams(r) - if err != nil { - return err - } - - aud := a.requestAud(ctx, r) - if params.Aud != "" { - aud = params.Aud - } - - if params.Email == "" && params.Phone == "" { - return unprocessableEntityError("Cannot create a user without either an email or phone") - } - - if params.Email != "" { - if err := a.validateEmail(ctx, params.Email); err != nil { - return err - } - if exists, err := models.IsDuplicatedEmail(a.db, instanceID, params.Email, aud); err != nil { - return internalServerError("Database error checking email").WithInternalError(err) - } else if exists { - return unprocessableEntityError("Email address already registered by another user") - } - } - - if params.Phone != "" { - params.Phone = a.formatPhoneNumber(params.Phone) - if isValid := a.validateE164Format(params.Phone); !isValid { - return unprocessableEntityError("Invalid phone format") - } - if exists, err := models.IsDuplicatedPhone(a.db, instanceID, params.Phone, aud); err != nil { - return internalServerError("Database error checking phone").WithInternalError(err) - } else if exists { - return unprocessableEntityError("Phone number already registered by another user") - } - } - - if params.Password == nil || *params.Password == "" { - password, err := password.Generate(64, 10, 0, false, true) - if err != nil { - return internalServerError("Error generating password").WithInternalError(err) - } - params.Password = &password - } - - user, err := models.NewUser(instanceID, params.Email, *params.Password, aud, params.UserMetaData) - if err != nil { - return internalServerError("Error creating user").WithInternalError(err) - } - if params.Phone != "" { - user.Phone = storage.NullString(params.Phone) - } - if user.AppMetaData == nil { - user.AppMetaData = make(map[string]interface{}) - } - user.AppMetaData["provider"] = "email" - user.AppMetaData["providers"] = []string{"email"} - - if params.BanDuration != "" { - duration, terr := time.ParseDuration(params.BanDuration) - if terr != nil { - return badRequestError("Invalid format for ban_duration: %v", terr) - } - t := time.Now().Add(duration) - user.BannedUntil = &t - } - - err = a.db.Transaction(func(tx *storage.Connection) error { - if terr := models.NewAuditLogEntry(tx, instanceID, adminUser, models.UserSignedUpAction, map[string]interface{}{ - "user_id": user.ID, - "user_email": user.Email, - "user_phone": user.Phone, - }); terr != nil { - return terr - } - - if terr := tx.Create(user); terr != nil { - return terr - } - - role := config.JWT.DefaultGroupName - if params.Role != "" { - role = params.Role - } - if terr := user.SetRole(tx, role); terr != nil { - return terr - } - - if params.EmailConfirm { - if terr := user.Confirm(tx); terr != nil { - return terr - } - } - - if params.PhoneConfirm { - if terr := user.ConfirmPhone(tx); terr != nil { - return terr - } - } - - return nil - }) - - if err != nil { - if strings.Contains(err.Error(), "Invalid format for ban_duration") { - return err - } - return internalServerError("Database error creating new user").WithInternalError(err) - } - - return sendJSON(w, http.StatusOK, user) -} - -// adminUserDelete delete a user -func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - user := getUser(ctx) - instanceID := getInstanceID(ctx) - adminUser := getAdminUser(ctx) - - err := a.db.Transaction(func(tx *storage.Connection) error { - if terr := models.NewAuditLogEntry(tx, instanceID, adminUser, models.UserDeletedAction, map[string]interface{}{ - "user_id": user.ID, - "user_email": user.Email, - "user_phone": user.Phone, - }); terr != nil { - return internalServerError("Error recording audit log entry").WithInternalError(terr) - } - - if terr := tx.Destroy(user); terr != nil { - return internalServerError("Database error deleting user").WithInternalError(terr) - } - return nil - }) - if err != nil { - return err - } - - return sendJSON(w, http.StatusOK, map[string]interface{}{}) -} diff --git a/api/admin_test.go b/api/admin_test.go deleted file mode 100644 index 9c421630a..000000000 --- a/api/admin_test.go +++ /dev/null @@ -1,599 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type AdminTestSuite struct { - suite.Suite - User *models.User - API *API - Config *conf.Configuration - - token string - instanceID uuid.UUID -} - -func TestAdmin(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &AdminTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *AdminTestSuite) SetupTest() { - models.TruncateAll(ts.API.db) - ts.Config.External.Email.Enabled = true - ts.token = ts.makeSuperAdmin("") -} - -func (ts *AdminTestSuite) makeSuperAdmin(email string) string { - u, err := models.NewUser(ts.instanceID, email, "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) - require.NoError(ts.T(), err, "Error making new user") - - u.Role = "supabase_admin" - - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) - require.NoError(ts.T(), err, "Error generating access token") - - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} - _, err = p.Parse(token, func(token *jwt.Token) (interface{}, error) { - return []byte(ts.Config.JWT.Secret), nil - }) - require.NoError(ts.T(), err, "Error parsing token") - - return token -} - -func (ts *AdminTestSuite) makeSystemUser() string { - u := models.NewSystemUser(uuid.Nil, ts.Config.JWT.Aud) - u.Role = "service_role" - - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) - require.NoError(ts.T(), err, "Error generating access token") - - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} - _, err = p.Parse(token, func(token *jwt.Token) (interface{}, error) { - return []byte(ts.Config.JWT.Secret), nil - }) - require.NoError(ts.T(), err, "Error parsing token") - - return token -} - -// TestAdminUsersUnauthorized tests API /admin/users route without authentication -func (ts *AdminTestSuite) TestAdminUsersUnauthorized() { - req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) -} - -// TestAdminUsers tests API /admin/users route -func (ts *AdminTestSuite) TestAdminUsers() { - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - assert.Equal(ts.T(), "; rel=\"last\"", w.HeaderMap.Get("Link")) - assert.Equal(ts.T(), "0", w.HeaderMap.Get("X-Total-Count")) -} - -// TestAdminUsers tests API /admin/users route -func (ts *AdminTestSuite) TestAdminUsers_Pagination() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - u, err = models.NewUser(ts.instanceID, "test2@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/admin/users?per_page=1", nil) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - assert.Equal(ts.T(), "; rel=\"next\", ; rel=\"last\"", w.HeaderMap.Get("Link")) - assert.Equal(ts.T(), "2", w.HeaderMap.Get("X-Total-Count")) - - data := make(map[string]interface{}) - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - for _, user := range data["users"].([]interface{}) { - assert.NotEmpty(ts.T(), user) - } -} - -// TestAdminUsers tests API /admin/users route -func (ts *AdminTestSuite) TestAdminUsers_SortAsc() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // if the created_at times are the same, then the sort order is not guaranteed - time.Sleep(1 * time.Second) - u, err = models.NewUser(ts.instanceID, "test2@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) - qv := req.URL.Query() - qv.Set("sort", "created_at asc") - req.URL.RawQuery = qv.Encode() - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := struct { - Users []*models.User `json:"users"` - Aud string `json:"aud"` - }{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - require.Len(ts.T(), data.Users, 2) - assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) - assert.Equal(ts.T(), "test2@example.com", data.Users[1].GetEmail()) -} - -// TestAdminUsers tests API /admin/users route -func (ts *AdminTestSuite) TestAdminUsers_SortDesc() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // if the created_at times are the same, then the sort order is not guaranteed - time.Sleep(1 * time.Second) - u, err = models.NewUser(ts.instanceID, "test2@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := struct { - Users []*models.User `json:"users"` - Aud string `json:"aud"` - }{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - require.Len(ts.T(), data.Users, 2) - assert.Equal(ts.T(), "test2@example.com", data.Users[0].GetEmail()) - assert.Equal(ts.T(), "test1@example.com", data.Users[1].GetEmail()) -} - -// TestAdminUsers tests API /admin/users route -func (ts *AdminTestSuite) TestAdminUsers_FilterEmail() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/admin/users?filter=test1", nil) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := struct { - Users []*models.User `json:"users"` - Aud string `json:"aud"` - }{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - require.Len(ts.T(), data.Users, 1) - assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) -} - -// TestAdminUsers tests API /admin/users route -func (ts *AdminTestSuite) TestAdminUsers_FilterName() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - u, err = models.NewUser(ts.instanceID, "test2@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/admin/users?filter=User", nil) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := struct { - Users []*models.User `json:"users"` - Aud string `json:"aud"` - }{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - require.Len(ts.T(), data.Users, 1) - assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) -} - -// TestAdminUserCreate tests API /admin/user route (POST) -func (ts *AdminTestSuite) TestAdminUserCreate() { - cases := []struct { - desc string - params map[string]interface{} - expected map[string]interface{} - }{ - { - desc: "With password", - params: map[string]interface{}{ - "email": "test1@example.com", - "phone": "123456789", - "password": "test1", - }, - expected: map[string]interface{}{ - "email": "test1@example.com", - "phone": "123456789", - "isAuthenticated": true, - }, - }, - { - desc: "Without password", - params: map[string]interface{}{ - "email": "test2@example.com", - "phone": "", - }, - expected: map[string]interface{}{ - "email": "test2@example.com", - "phone": "", - "isAuthenticated": false, - }, - }, - { - desc: "With empty string password", - params: map[string]interface{}{ - "email": "test3@example.com", - "phone": "", - "password": "", - }, - expected: map[string]interface{}{ - "email": "test3@example.com", - "phone": "", - "isAuthenticated": false, - }, - }, - { - desc: "Ban created user", - params: map[string]interface{}{ - "email": "test4@example.com", - "phone": "", - "password": "test1", - "ban_duration": "24h", - }, - expected: map[string]interface{}{ - "email": "test4@example.com", - "phone": "", - "isAuthenticated": true, - }, - }, - } - - for _, c := range cases { - ts.Run(c.desc, func() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - ts.Config.External.Phone.Enabled = true - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := models.User{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - assert.Equal(ts.T(), c.expected["email"], data.GetEmail()) - assert.Equal(ts.T(), c.expected["phone"], data.GetPhone()) - assert.Equal(ts.T(), "email", data.AppMetaData["provider"]) - assert.Equal(ts.T(), []interface{}{"email"}, data.AppMetaData["providers"]) - - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, data.GetEmail(), ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - var expectedPassword string - if _, ok := c.params["password"]; ok { - expectedPassword = fmt.Sprintf("%v", c.params["password"]) - } - - assert.Equal(ts.T(), c.expected["isAuthenticated"], u.Authenticate(expectedPassword)) - }) - } -} - -// TestAdminUserGet tests API /admin/user route (GET) -func (ts *AdminTestSuite) TestAdminUserGet() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test Get User"}) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/users/%s", u.ID), nil) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := make(map[string]interface{}) - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - assert.Equal(ts.T(), data["email"], "test1@example.com") - assert.NotNil(ts.T(), data["app_metadata"]) - assert.NotNil(ts.T(), data["user_metadata"]) - md := data["user_metadata"].(map[string]interface{}) - assert.Len(ts.T(), md, 1) - assert.Equal(ts.T(), "Test Get User", md["full_name"]) -} - -// TestAdminUserUpdate tests API /admin/user route (UPDATE) -func (ts *AdminTestSuite) TestAdminUserUpdate() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "role": "testing", - "app_metadata": map[string]interface{}{ - "roles": []string{"writer", "editor"}, - }, - "user_metadata": map[string]interface{}{ - "name": "David", - }, - "ban_duration": "24h", - })) - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/users/%s", u.ID), &buffer) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := models.User{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - assert.Equal(ts.T(), "testing", data.Role) - assert.NotNil(ts.T(), data.UserMetaData) - assert.Equal(ts.T(), "David", data.UserMetaData["name"]) - - assert.NotNil(ts.T(), data.AppMetaData) - assert.Len(ts.T(), data.AppMetaData["roles"], 2) - assert.Contains(ts.T(), data.AppMetaData["roles"], "writer") - assert.Contains(ts.T(), data.AppMetaData["roles"], "editor") - assert.NotNil(ts.T(), data.BannedUntil) -} - -// TestAdminUserUpdate tests API /admin/user route (UPDATE) as system user -func (ts *AdminTestSuite) TestAdminUserUpdateAsSystemUser() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "role": "testing", - "app_metadata": map[string]interface{}{ - "roles": []string{"writer", "editor"}, - }, - "user_metadata": map[string]interface{}{ - "name": "David", - }, - })) - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/users/%s", u.ID), &buffer) - - token := ts.makeSystemUser() - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := make(map[string]interface{}) - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - assert.Equal(ts.T(), data["role"], "testing") - - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test1@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.Equal(ts.T(), u.Role, "testing") - require.NotNil(ts.T(), u.UserMetaData) - require.Contains(ts.T(), u.UserMetaData, "name") - assert.Equal(ts.T(), u.UserMetaData["name"], "David") - require.NotNil(ts.T(), u.AppMetaData) - require.Contains(ts.T(), u.AppMetaData, "roles") - assert.Len(ts.T(), u.AppMetaData["roles"], 2) - assert.Contains(ts.T(), u.AppMetaData["roles"], "writer") - assert.Contains(ts.T(), u.AppMetaData["roles"], "editor") -} - -func (ts *AdminTestSuite) TestAdminUserUpdatePasswordFailed() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - var updateEndpoint = fmt.Sprintf("/admin/users/%s", u.ID) - ts.Config.PasswordMinLength = 6 - ts.Run("Password doesn't meet minimum length", func() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "password": "", - })) - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPut, updateEndpoint, &buffer) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusUnprocessableEntity, w.Code) - }) -} - -func (ts *AdminTestSuite) TestAdminUserUpdateBannedUntilFailed() { - u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - var updateEndpoint = fmt.Sprintf("/admin/users/%s", u.ID) - ts.Config.PasswordMinLength = 6 - ts.Run("Incorrect format for ban_duration", func() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "ban_duration": "24", - })) - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPut, updateEndpoint, &buffer) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusBadRequest, w.Code) - }) -} - -// TestAdminUserDelete tests API /admin/user route (DELETE) -func (ts *AdminTestSuite) TestAdminUserDelete() { - u, err := models.NewUser(ts.instanceID, "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error making new user") - require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/users/%s", u.ID), nil) - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) -} - -func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() { - var cases = []struct { - desc string - customConfig *conf.Configuration - userData map[string]interface{} - expected int - }{ - { - "Email Signups Disabled", - &conf.Configuration{ - JWT: ts.Config.JWT, - External: conf.ProviderConfiguration{ - Email: conf.EmailProviderConfiguration{ - Enabled: false, - }, - }, - }, - map[string]interface{}{ - "email": "test1@example.com", - "password": "test1", - }, - http.StatusOK, - }, - { - "Phone Signups Disabled", - &conf.Configuration{ - JWT: ts.Config.JWT, - External: conf.ProviderConfiguration{ - Phone: conf.PhoneProviderConfiguration{ - Enabled: false, - }, - }, - }, - map[string]interface{}{ - "phone": "123456789", - "password": "test1", - }, - http.StatusOK, - }, - { - "All Signups Disabled", - &conf.Configuration{ - JWT: ts.Config.JWT, - DisableSignup: true, - }, - map[string]interface{}{ - "email": "test2@example.com", - "password": "test2", - }, - http.StatusOK, - }, - } - - for _, c := range cases { - ts.Run(c.desc, func() { - // Initialize user data - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.userData)) - - // Setup request - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - - *ts.Config = *c.customConfig - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), c.expected, w.Code) - }) - } -} diff --git a/api/api.go b/api/api.go deleted file mode 100644 index 6ea197856..000000000 --- a/api/api.go +++ /dev/null @@ -1,282 +0,0 @@ -package api - -import ( - "context" - "net/http" - "os" - "os/signal" - "regexp" - "syscall" - "time" - - "github.com/didip/tollbooth/v5" - "github.com/didip/tollbooth/v5/limiter" - "github.com/go-chi/chi" - "github.com/gofrs/uuid" - "github.com/imdario/mergo" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/mailer" - "github.com/netlify/gotrue/storage" - "github.com/rs/cors" - "github.com/sebest/xff" - "github.com/sirupsen/logrus" -) - -const ( - audHeaderName = "X-JWT-AUD" - defaultVersion = "unknown version" -) - -var bearerRegexp = regexp.MustCompile(`^(?:B|b)earer (\S+$)`) - -// API is the main REST API -type API struct { - handler http.Handler - db *storage.Connection - config *conf.GlobalConfiguration - version string -} - -// ListenAndServe starts the REST API -func (a *API) ListenAndServe(hostAndPort string) { - log := logrus.WithField("component", "api") - server := &http.Server{ - Addr: hostAndPort, - Handler: a.handler, - } - - done := make(chan struct{}) - defer close(done) - go func() { - waitForTermination(log, done) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - server.Shutdown(ctx) - }() - - if err := server.ListenAndServe(); err != http.ErrServerClosed { - log.WithError(err).Fatal("http server listen failed") - } -} - -// WaitForShutdown blocks until the system signals termination or done has a value -func waitForTermination(log logrus.FieldLogger, done <-chan struct{}) { - signals := make(chan os.Signal, 1) - signal.Notify(signals, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) - select { - case sig := <-signals: - log.Infof("Triggering shutdown from signal %s", sig) - case <-done: - log.Infof("Shutting down...") - } -} - -// NewAPI instantiates a new REST API -func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection) *API { - return NewAPIWithVersion(context.Background(), globalConfig, db, defaultVersion) -} - -// NewAPIWithVersion creates a new REST API using the specified version -func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API { - api := &API{config: globalConfig, db: db, version: version} - - xffmw, _ := xff.Default() - logger := newStructuredLogger(logrus.StandardLogger()) - - r := newRouter() - r.UseBypass(xffmw.Handler) - r.Use(addRequestID(globalConfig)) - r.Use(recoverer) - r.UseBypass(tracer) - - r.Get("/health", api.HealthCheck) - - r.Route("/callback", func(r *router) { - r.UseBypass(logger) - r.Use(api.loadOAuthState) - - if globalConfig.MultiInstanceMode { - r.Use(api.loadInstanceConfig) - } - r.Get("/", api.ExternalProviderCallback) - r.Post("/", api.ExternalProviderCallback) - }) - - r.Route("/", func(r *router) { - r.UseBypass(logger) - - if globalConfig.MultiInstanceMode { - r.Use(api.loadJWSSignatureHeader) - r.Use(api.loadInstanceConfig) - } - - r.Get("/settings", api.Settings) - - r.Get("/authorize", api.ExternalProviderRedirect) - - sharedLimiter := api.limitEmailSentHandler() - r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite) - r.With(sharedLimiter).With(api.verifyCaptcha).Post("/signup", api.Signup) - r.With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) - r.With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) - - r.With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp) - - r.With(api.limitHandler( - // Allow requests at a rate of 30 per 5 minutes. - tollbooth.NewLimiter(30.0/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).Post("/token", api.Token) - - r.With(api.limitHandler( - // Allow requests at a rate of 30 per 5 minutes. - tollbooth.NewLimiter(30.0/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).Route("/verify", func(r *router) { - r.Get("/", api.Verify) - r.Post("/", api.Verify) - }) - - r.With(api.requireAuthentication).Post("/logout", api.Logout) - - r.Route("/user", func(r *router) { - r.Use(api.requireAuthentication) - r.Get("/", api.UserGet) - r.With(sharedLimiter).Put("/", api.UserUpdate) - }) - - r.Route("/admin", func(r *router) { - r.Use(api.requireAdminCredentials) - - r.Route("/audit", func(r *router) { - r.Get("/", api.adminAuditLog) - }) - - r.Route("/users", func(r *router) { - r.Get("/", api.adminUsers) - r.Post("/", api.adminUserCreate) - - r.Route("/{user_id}", func(r *router) { - r.Use(api.loadUser) - - r.Get("/", api.adminUserGet) - r.Put("/", api.adminUserUpdate) - r.Delete("/", api.adminUserDelete) - }) - }) - - r.Post("/generate_link", api.GenerateLink) - }) - - r.Route("/saml", func(r *router) { - r.Route("/acs", func(r *router) { - r.Use(api.loadSAMLState) - r.Post("/", api.ExternalProviderCallback) - }) - - r.Get("/metadata", api.SAMLMetadata) - }) - }) - - if globalConfig.MultiInstanceMode { - // Operator microservice API - r.WithBypass(logger).Get("/", api.GetAppManifest) - r.Route("/instances", func(r *router) { - r.UseBypass(logger) - - r.Post("/", api.CreateInstance) - r.Route("/{instance_id}", func(r *router) { - r.Use(api.loadInstance) - - r.Get("/", api.GetInstance) - r.Put("/", api.UpdateInstance) - r.Delete("/", api.DeleteInstance) - }) - }) - } - - corsHandler := cors.New(cors.Options{ - AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", audHeaderName, useCookieHeader}, - AllowCredentials: true, - }) - - api.handler = corsHandler.Handler(chi.ServerBaseContext(ctx, r)) - return api -} - -// NewAPIFromConfigFile creates a new REST API using the provided configuration file. -func NewAPIFromConfigFile(filename string, version string) (*API, *conf.Configuration, error) { - globalConfig, err := conf.LoadGlobal(filename) - if err != nil { - return nil, nil, err - } - - config, err := conf.LoadConfig(filename) - if err != nil { - return nil, nil, err - } - - ctx, err := WithInstanceConfig(context.Background(), config, uuid.Nil) - if err != nil { - logrus.Fatalf("Error loading instance config: %+v", err) - } - - db, err := storage.Dial(globalConfig) - if err != nil { - return nil, nil, err - } - - return NewAPIWithVersion(ctx, globalConfig, db, version), config, nil -} - -// HealthCheck endpoint indicates if the gotrue api service is available -func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error { - return sendJSON(w, http.StatusOK, map[string]string{ - "version": a.version, - "name": "GoTrue", - "description": "GoTrue is a user registration and authentication API", - }) -} - -// WithInstanceConfig adds the instanceID and tenant config to the context -func WithInstanceConfig(ctx context.Context, config *conf.Configuration, instanceID uuid.UUID) (context.Context, error) { - ctx = withConfig(ctx, config) - ctx = withInstanceID(ctx, instanceID) - return ctx, nil -} - -// Mailer returns NewMailer with the current tenant config -func (a *API) Mailer(ctx context.Context) mailer.Mailer { - config := a.getConfig(ctx) - return mailer.NewMailer(config) -} - -func (a *API) getConfig(ctx context.Context) *conf.Configuration { - obj := ctx.Value(configKey) - if obj == nil { - return nil - } - - config := obj.(*conf.Configuration) - - // Merge global & per-instance external config for multi-instance mode - if a.config.MultiInstanceMode { - extConfig := (*a.config).External - if err := mergo.MergeWithOverwrite(&extConfig, config.External); err != nil { - return nil - } - config.External = extConfig - - // Merge global & per-instance smtp config for multi-instance mode - smtpConfig := (*a.config).SMTP - if err := mergo.MergeWithOverwrite(&smtpConfig, config.SMTP); err != nil { - return nil - } - config.SMTP = smtpConfig - } - - return config -} diff --git a/api/api_test.go b/api/api_test.go deleted file mode 100644 index a5522f60e..000000000 --- a/api/api_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package api - -import ( - "context" - "math/rand" - "testing" - "time" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/netlify/gotrue/storage/test" - "github.com/stretchr/testify/require" -) - -const ( - apiTestVersion = "1" - apiTestConfig = "../hack/test.env" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -// setupAPIForTest creates a new API to run tests with. -// Using this function allows us to keep track of the database connection -// and cleaning up data between tests. -func setupAPIForTest() (*API, *conf.Configuration, error) { - return setupAPIForTestWithCallback(nil) -} - -func setupAPIForMultiinstanceTest() (*API, *conf.Configuration, error) { - cb := func(gc *conf.GlobalConfiguration, c *conf.Configuration, conn *storage.Connection) (uuid.UUID, error) { - gc.MultiInstanceMode = true - return uuid.Nil, nil - } - - return setupAPIForTestWithCallback(cb) -} - -func setupAPIForTestForInstance() (*API, *conf.Configuration, uuid.UUID, error) { - instanceID := uuid.Must(uuid.NewV4()) - cb := func(gc *conf.GlobalConfiguration, c *conf.Configuration, conn *storage.Connection) (uuid.UUID, error) { - err := conn.Create(&models.Instance{ - ID: instanceID, - UUID: testUUID, - BaseConfig: c, - }) - return instanceID, err - } - - api, conf, err := setupAPIForTestWithCallback(cb) - if err != nil { - return nil, nil, uuid.Nil, err - } - return api, conf, instanceID, nil -} - -func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *conf.Configuration, *storage.Connection) (uuid.UUID, error)) (*API, *conf.Configuration, error) { - globalConfig, err := conf.LoadGlobal(apiTestConfig) - if err != nil { - return nil, nil, err - } - - conn, err := test.SetupDBConnection(globalConfig) - if err != nil { - return nil, nil, err - } - - config, err := conf.LoadConfig(apiTestConfig) - if err != nil { - conn.Close() - return nil, nil, err - } - - instanceID := uuid.Nil - if cb != nil { - instanceID, err = cb(globalConfig, config, conn) - if err != nil { - conn.Close() - return nil, nil, err - } - } - - ctx, err := WithInstanceConfig(context.Background(), config, instanceID) - if err != nil { - conn.Close() - return nil, nil, err - } - - return NewAPIWithVersion(ctx, globalConfig, conn, apiTestVersion), config, nil -} - -func TestEmailEnabledByDefault(t *testing.T) { - api, _, err := setupAPIForTest() - require.NoError(t, err) - - require.True(t, api.config.External.Email.Enabled) -} - -const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func randStringBytes(n int) string { - b := make([]byte, n) - for i := range b { - b[i] = letterBytes[rand.Int63()%int64(len(letterBytes))] - } - return string(b) -} diff --git a/api/auth.go b/api/auth.go deleted file mode 100644 index 35d82444b..000000000 --- a/api/auth.go +++ /dev/null @@ -1,73 +0,0 @@ -package api - -import ( - "context" - "fmt" - "net/http" - "time" - - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" -) - -// requireAuthentication checks incoming requests for tokens presented using the Authorization header -func (a *API) requireAuthentication(w http.ResponseWriter, r *http.Request) (context.Context, error) { - token, err := a.extractBearerToken(w, r) - config := getConfig(r.Context()) - if err != nil { - a.clearCookieTokens(config, w) - return nil, err - } - - return a.parseJWTClaims(token, r, w) -} - -func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) { - // Find the administrative user - claims := getClaims(ctx) - if claims == nil { - fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token") - return nil, unauthorizedError("Invalid token") - } - - adminRoles := a.getConfig(ctx).JWT.AdminRoles - - if isStringInSlice(claims.Role, adminRoles) { - // successful authentication - return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil - } - - fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'") - return nil, unauthorizedError("User not allowed") -} - -func (a *API) extractBearerToken(w http.ResponseWriter, r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", unauthorizedError("This endpoint requires a Bearer token") - } - - matches := bearerRegexp.FindStringSubmatch(authHeader) - if len(matches) != 2 { - return "", unauthorizedError("This endpoint requires a Bearer token") - } - - return matches[1], nil -} - -func (a *API) parseJWTClaims(bearer string, r *http.Request, w http.ResponseWriter) (context.Context, error) { - ctx := r.Context() - config := a.getConfig(ctx) - - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} - token, err := p.ParseWithClaims(bearer, &GoTrueClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte(config.JWT.Secret), nil - }) - if err != nil { - a.clearCookieTokens(config, w) - return nil, unauthorizedError("Invalid token: %v", err) - } - - return withToken(ctx, token), nil -} diff --git a/api/errors.go b/api/errors.go deleted file mode 100644 index f839357de..000000000 --- a/api/errors.go +++ /dev/null @@ -1,274 +0,0 @@ -package api - -import ( - "context" - "fmt" - "net/http" - "os" - "runtime/debug" - - "github.com/netlify/gotrue/conf" - "github.com/pkg/errors" -) - -// Common error messages during signup flow -var ( - DuplicateEmailMsg = "A user with this email address has already been registered" - UserExistsError error = errors.New("User already exists") -) - -var oauthErrorMap = map[int]string{ - http.StatusBadRequest: "invalid_request", - http.StatusUnauthorized: "unauthorized_client", - http.StatusForbidden: "access_denied", - http.StatusInternalServerError: "server_error", - http.StatusServiceUnavailable: "temporarily_unavailable", -} - -// OAuthError is the JSON handler for OAuth2 error responses -type OAuthError struct { - Err string `json:"error"` - Description string `json:"error_description,omitempty"` - InternalError error `json:"-"` - InternalMessage string `json:"-"` -} - -func (e *OAuthError) Error() string { - if e.InternalMessage != "" { - return e.InternalMessage - } - return fmt.Sprintf("%s: %s", e.Err, e.Description) -} - -// WithInternalError adds internal error information to the error -func (e *OAuthError) WithInternalError(err error) *OAuthError { - e.InternalError = err - return e -} - -// WithInternalMessage adds internal message information to the error -func (e *OAuthError) WithInternalMessage(fmtString string, args ...interface{}) *OAuthError { - e.InternalMessage = fmt.Sprintf(fmtString, args...) - return e -} - -// Cause returns the root cause error -func (e *OAuthError) Cause() error { - if e.InternalError != nil { - return e.InternalError - } - return e -} - -func invalidPasswordLengthError(config *conf.Configuration) *HTTPError { - return unprocessableEntityError(fmt.Sprintf("Password should be at least %d characters", config.PasswordMinLength)) -} - -func invalidSignupError(config *conf.Configuration) *HTTPError { - var msg string - if config.External.Email.Enabled && config.External.Phone.Enabled { - msg = "To signup, please provide your email or phone number" - } else if config.External.Email.Enabled { - msg = "To signup, please provide your email" - } else if config.External.Phone.Enabled { - msg = "To signup, please provide your phone number" - } else { - // 3rd party OAuth signups - msg = "To signup, please provide required fields" - } - return unprocessableEntityError(msg) -} - -func oauthError(err string, description string) *OAuthError { - return &OAuthError{Err: err, Description: description} -} - -func badRequestError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusBadRequest, fmtString, args...) -} - -func internalServerError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusInternalServerError, fmtString, args...) -} - -func notFoundError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusNotFound, fmtString, args...) -} - -func acceptedTokenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusAccepted, fmtString, args...) -} - -func expiredTokenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusGone, fmtString, args...) -} - -func unauthorizedError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnauthorized, fmtString, args...) -} - -func forbiddenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusForbidden, fmtString, args...) -} - -func unprocessableEntityError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnprocessableEntity, fmtString, args...) -} - -func tooManyRequestsError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusTooManyRequests, fmtString, args...) -} - -// HTTPError is an error with a message and an HTTP status code. -type HTTPError struct { - Code int `json:"code"` - Message string `json:"msg"` - InternalError error `json:"-"` - InternalMessage string `json:"-"` - ErrorID string `json:"error_id,omitempty"` -} - -func (e *HTTPError) Error() string { - if e.InternalMessage != "" { - return e.InternalMessage - } - return fmt.Sprintf("%d: %s", e.Code, e.Message) -} - -func (e *HTTPError) Is(target error) bool { - return e.Error() == target.Error() -} - -// Cause returns the root cause error -func (e *HTTPError) Cause() error { - if e.InternalError != nil { - return e.InternalError - } - return e -} - -// WithInternalError adds internal error information to the error -func (e *HTTPError) WithInternalError(err error) *HTTPError { - e.InternalError = err - return e -} - -// WithInternalMessage adds internal message information to the error -func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError { - e.InternalMessage = fmt.Sprintf(fmtString, args...) - return e -} - -func httpError(code int, fmtString string, args ...interface{}) *HTTPError { - return &HTTPError{ - Code: code, - Message: fmt.Sprintf(fmtString, args...), - } -} - -// OTPError is a custom error struct for phone auth errors -type OTPError struct { - Err string `json:"error"` - Description string `json:"error_description,omitempty"` - InternalError error `json:"-"` - InternalMessage string `json:"-"` -} - -func (e *OTPError) Error() string { - if e.InternalMessage != "" { - return e.InternalMessage - } - return fmt.Sprintf("%s: %s", e.Err, e.Description) -} - -// WithInternalError adds internal error information to the error -func (e *OTPError) WithInternalError(err error) *OTPError { - e.InternalError = err - return e -} - -// WithInternalMessage adds internal message information to the error -func (e *OTPError) WithInternalMessage(fmtString string, args ...interface{}) *OTPError { - e.InternalMessage = fmt.Sprintf(fmtString, args...) - return e -} - -// Cause returns the root cause error -func (e *OTPError) Cause() error { - if e.InternalError != nil { - return e.InternalError - } - return e -} - -func otpError(err string, description string) *OTPError { - return &OTPError{Err: err, Description: description} -} - -// Recoverer is a middleware that recovers from panics, logs the panic (and a -// backtrace), and returns a HTTP 500 (Internal Server Error) status if -// possible. Recoverer prints a request ID if one is provided. -func recoverer(w http.ResponseWriter, r *http.Request) (context.Context, error) { - defer func() { - if rvr := recover(); rvr != nil { - - logEntry := getLogEntry(r) - if logEntry != nil { - logEntry.Panic(rvr, debug.Stack()) - } else { - fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) - debug.PrintStack() - } - - se := &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), - } - handleError(se, w, r) - } - }() - - return nil, nil -} - -// ErrorCause is an error interface that contains the method Cause() for returning root cause errors -type ErrorCause interface { - Cause() error -} - -func handleError(err error, w http.ResponseWriter, r *http.Request) { - log := getLogEntry(r) - errorID := getRequestID(r.Context()) - switch e := err.(type) { - case *HTTPError: - if e.Code >= http.StatusInternalServerError { - e.ErrorID = errorID - // this will get us the stack trace too - log.WithError(e.Cause()).Error(e.Error()) - } else { - log.WithError(e.Cause()).Info(e.Error()) - } - if jsonErr := sendJSON(w, e.Code, e); jsonErr != nil { - handleError(jsonErr, w, r) - } - case *OAuthError: - log.WithError(e.Cause()).Info(e.Error()) - if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) - } - case *OTPError: - log.WithError(e.Cause()).Info(e.Error()) - if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) - } - case ErrorCause: - handleError(e.Cause(), w, r) - default: - log.WithError(e).Errorf("Unhandled server error: %s", e.Error()) - // hide real error details from response to prevent info leaks - w.WriteHeader(http.StatusInternalServerError) - if _, writeErr := w.Write([]byte(`{"code":500,"msg":"Internal server error","error_id":"` + errorID + `"}`)); writeErr != nil { - log.WithError(writeErr).Error("Error writing generic error message") - } - } -} diff --git a/api/external.go b/api/external.go deleted file mode 100644 index 328d42439..000000000 --- a/api/external.go +++ /dev/null @@ -1,512 +0,0 @@ -package api - -import ( - "context" - "errors" - "fmt" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/api/provider" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/sirupsen/logrus" -) - -// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow -type ExternalProviderClaims struct { - NetlifyMicroserviceClaims - Provider string `json:"provider"` - InviteToken string `json:"invite_token,omitempty"` - Referrer string `json:"referrer,omitempty"` -} - -// ExternalSignupParams are the parameters the Signup endpoint accepts -type ExternalSignupParams struct { - Provider string `json:"provider"` - Code string `json:"code"` -} - -// ExternalProviderRedirect redirects the request to the corresponding oauth provider -func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - - providerType := r.URL.Query().Get("provider") - scopes := r.URL.Query().Get("scopes") - - p, err := a.Provider(ctx, providerType, scopes) - if err != nil { - return badRequestError("Unsupported provider: %+v", err).WithInternalError(err) - } - - inviteToken := r.URL.Query().Get("invite_token") - if inviteToken != "" { - _, userErr := models.FindUserByConfirmationToken(a.db, inviteToken) - if userErr != nil { - if models.IsNotFoundError(userErr) { - return notFoundError(userErr.Error()) - } - return internalServerError("Database error finding user").WithInternalError(userErr) - } - } - - redirectURL := a.getRedirectURLOrReferrer(r, r.URL.Query().Get("redirect_to")) - log := getLogEntry(r) - log.WithField("provider", providerType).Info("Redirecting to external provider") - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, ExternalProviderClaims{ - NetlifyMicroserviceClaims: NetlifyMicroserviceClaims{ - StandardClaims: jwt.StandardClaims{ - ExpiresAt: time.Now().Add(5 * time.Minute).Unix(), - }, - SiteURL: config.SiteURL, - InstanceID: getInstanceID(ctx).String(), - NetlifyID: getNetlifyID(ctx), - }, - Provider: providerType, - InviteToken: inviteToken, - Referrer: redirectURL, - }) - tokenString, err := token.SignedString([]byte(config.JWT.Secret)) - if err != nil { - return internalServerError("Error creating state").WithInternalError(err) - } - - var authURL string - switch externalProvider := p.(type) { - case *provider.TwitterProvider: - authURL = externalProvider.AuthCodeURL(tokenString) - err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w) - if err != nil { - return internalServerError("Error storing request token in session").WithInternalError(err) - } - default: - authURL = p.AuthCodeURL(tokenString) - } - - http.Redirect(w, r, authURL, http.StatusFound) - return nil -} - -// ExternalProviderCallback handles the callback endpoint in the external oauth provider flow -func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) error { - a.redirectErrors(a.internalExternalProviderCallback, w, r) - return nil -} - -func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - instanceID := getInstanceID(ctx) - - providerType := getExternalProviderType(ctx) - var userData *provider.UserProvidedData - var providerToken string - if providerType == "saml" { - samlUserData, err := a.samlCallback(ctx, r) - if err != nil { - return err - } - userData = samlUserData - } else if providerType == "twitter" { - // future OAuth1.0 providers will use this method - oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType) - if err != nil { - return err - } - userData = oAuthResponseData.userData - providerToken = oAuthResponseData.token - } else { - oAuthResponseData, err := a.oAuthCallback(ctx, r, providerType) - if err != nil { - return err - } - userData = oAuthResponseData.userData - providerToken = oAuthResponseData.token - } - - var user *models.User - var token *AccessTokenResponse - err := a.db.Transaction(func(tx *storage.Connection) error { - var terr error - inviteToken := getInviteToken(ctx) - if inviteToken != "" { - if user, terr = a.processInvite(ctx, tx, userData, instanceID, inviteToken, providerType); terr != nil { - return terr - } - } else { - aud := a.requestAud(ctx, r) - var emailData provider.Email - var identityData map[string]interface{} - if userData.Metadata != nil { - identityData, terr = userData.Metadata.ToMap() - if terr != nil { - return terr - } - } - - var identity *models.Identity - // check if identity exists - if identity, terr = models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType); terr != nil { - if models.IsNotFoundError(terr) { - user, emailData, terr = a.getUserByVerifiedEmail(tx, config, userData.Emails, instanceID, aud) - if terr != nil && !models.IsNotFoundError(terr) { - return internalServerError("Error checking for existing users").WithInternalError(terr) - } - if user != nil { - if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil { - return terr - } - if terr = user.UpdateAppMetaDataProviders(tx); terr != nil { - return terr - } - } else { - if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") - } - - // prefer primary email for new signups - emailData = userData.Emails[0] - for _, e := range userData.Emails { - if e.Primary { - emailData = e - break - } - } - - params := &SignupParams{ - Provider: providerType, - Email: emailData.Email, - Aud: aud, - Data: identityData, - } - - user, terr = a.signupNewUser(ctx, tx, params) - if terr != nil { - return terr - } - - if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil { - return terr - } - } - } else { - return terr - } - } - - if identity != nil && user == nil { - // get user associated with identity - user, terr = models.FindUserByID(tx, identity.UserID) - if terr != nil { - return terr - } - identity.IdentityData = identityData - if terr = tx.UpdateOnly(identity, "identity_data", "last_sign_in_at"); terr != nil { - return terr - } - // email & verified status might have changed if identity's email changed - emailData = provider.Email{ - Email: userData.Metadata.Email, - Verified: userData.Metadata.EmailVerified, - } - if terr = user.UpdateUserMetaData(tx, identityData); terr != nil { - return terr - } - if terr = user.UpdateAppMetaDataProviders(tx); terr != nil { - return terr - } - } - - if user.IsBanned() { - return unauthorizedError("User is unauthorized") - } - - if !user.IsConfirmed() { - if !emailData.Verified && !config.Mailer.Autoconfirm { - mailer := a.Mailer(ctx) - referrer := a.getReferrer(r) - if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer); terr != nil { - if errors.Is(terr, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every minute") - } - return internalServerError("Error sending confirmation mail").WithInternalError(terr) - } - // email must be verified to issue a token - return nil - } - - if terr := models.NewAuditLogEntry(tx, instanceID, user, models.UserSignedUpAction, nil); terr != nil { - return terr - } - if terr = triggerEventHooks(ctx, tx, SignupEvent, user, instanceID, config); terr != nil { - return terr - } - - // fall through to auto-confirm and issue token - if terr = user.Confirm(tx); terr != nil { - return internalServerError("Error updating user").WithInternalError(terr) - } - } else { - if terr := models.NewAuditLogEntry(tx, instanceID, user, models.LoginAction, nil); terr != nil { - return terr - } - if terr = triggerEventHooks(ctx, tx, LoginEvent, user, instanceID, config); terr != nil { - return terr - } - } - } - - token, terr = a.issueRefreshToken(ctx, tx, user) - if terr != nil { - return oauthError("server_error", terr.Error()) - } - return nil - }) - if err != nil { - return err - } - - rurl := a.getExternalRedirectURL(r) - if token != nil { - q := url.Values{} - q.Set("provider_token", providerToken) - q.Set("access_token", token.Token) - q.Set("token_type", token.TokenType) - q.Set("expires_in", strconv.Itoa(token.ExpiresIn)) - q.Set("refresh_token", token.RefreshToken) - rurl += "#" + q.Encode() - - if err := a.setCookieTokens(config, token, false, w); err != nil { - return internalServerError("Failed to set JWT cookie. %s", err) - } - } else { - rurl = a.prepErrorRedirectURL(unauthorizedError("Unverified email with %v", providerType), r, rurl) - } - - http.Redirect(w, r, rurl, http.StatusFound) - return nil -} - -func (a *API) processInvite(ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, instanceID uuid.UUID, inviteToken, providerType string) (*models.User, error) { - config := a.getConfig(ctx) - user, err := models.FindUserByConfirmationToken(tx, inviteToken) - if err != nil { - if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()) - } - return nil, internalServerError("Database error finding user").WithInternalError(err) - } - - var emailData *provider.Email - var emails []string - for _, e := range userData.Emails { - emails = append(emails, e.Email) - if user.GetEmail() == e.Email { - emailData = &e - break - } - } - - if emailData == nil { - return nil, badRequestError("Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) - } - - var identityData map[string]interface{} - if userData.Metadata != nil { - identityData, err = userData.Metadata.ToMap() - if err != nil { - return nil, internalServerError("Error serialising user metadata").WithInternalError(err) - } - } - if _, err := a.createNewIdentity(tx, user, providerType, identityData); err != nil { - return nil, err - } - if err = user.UpdateAppMetaData(tx, map[string]interface{}{ - "provider": providerType, - }); err != nil { - return nil, err - } - if err = user.UpdateAppMetaDataProviders(tx); err != nil { - return nil, err - } - if err := user.UpdateUserMetaData(tx, identityData); err != nil { - return nil, internalServerError("Database error updating user").WithInternalError(err) - } - - if err := models.NewAuditLogEntry(tx, instanceID, user, models.InviteAcceptedAction, nil); err != nil { - return nil, err - } - if err := triggerEventHooks(ctx, tx, SignupEvent, user, instanceID, config); err != nil { - return nil, err - } - - // confirm because they were able to respond to invite email - if err := user.Confirm(tx); err != nil { - return nil, err - } - return user, nil -} - -func (a *API) loadExternalState(ctx context.Context, state string) (context.Context, error) { - config := a.getConfig(ctx) - claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} - _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { - return []byte(config.JWT.Secret), nil - }) - if err != nil || claims.Provider == "" { - return nil, badRequestError("OAuth state is invalid: %v", err) - } - if claims.InviteToken != "" { - ctx = withInviteToken(ctx, claims.InviteToken) - } - if claims.Referrer != "" { - ctx = withExternalReferrer(ctx, claims.Referrer) - } - - ctx = withExternalProviderType(ctx, claims.Provider) - return withSignature(ctx, state), nil -} - -// Provider returns a Provider interface for the given name. -func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, error) { - config := a.getConfig(ctx) - name = strings.ToLower(name) - - switch name { - case "apple": - return provider.NewAppleProvider(config.External.Apple) - case "azure": - return provider.NewAzureProvider(config.External.Azure, scopes) - case "bitbucket": - return provider.NewBitbucketProvider(config.External.Bitbucket) - case "discord": - return provider.NewDiscordProvider(config.External.Discord, scopes) - case "github": - return provider.NewGithubProvider(config.External.Github, scopes) - case "gitlab": - return provider.NewGitlabProvider(config.External.Gitlab, scopes) - case "google": - return provider.NewGoogleProvider(config.External.Google, scopes) - case "linkedin": - return provider.NewLinkedinProvider(config.External.Linkedin, scopes) - case "facebook": - return provider.NewFacebookProvider(config.External.Facebook, scopes) - case "notion": - return provider.NewNotionProvider(config.External.Notion) - case "spotify": - return provider.NewSpotifyProvider(config.External.Spotify, scopes) - case "slack": - return provider.NewSlackProvider(config.External.Slack, scopes) - case "twitch": - return provider.NewTwitchProvider(config.External.Twitch, scopes) - case "twitter": - return provider.NewTwitterProvider(config.External.Twitter, scopes) - case "saml": - return provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx)) - default: - return nil, fmt.Errorf("Provider %s could not be found", name) - } -} - -func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - log := getLogEntry(r) - errorID := getRequestID(ctx) - err := handler(w, r) - if err != nil { - q := getErrorQueryString(err, errorID, log) - http.Redirect(w, r, a.getExternalRedirectURL(r)+"?"+q.Encode(), http.StatusFound) - } -} - -func getErrorQueryString(err error, errorID string, log logrus.FieldLogger) *url.Values { - q := url.Values{} - switch e := err.(type) { - case *HTTPError: - if str, ok := oauthErrorMap[e.Code]; ok { - q.Set("error", str) - } else { - q.Set("error", "server_error") - } - if e.Code >= http.StatusInternalServerError { - e.ErrorID = errorID - // this will get us the stack trace too - log.WithError(e.Cause()).Error(e.Error()) - } else { - log.WithError(e.Cause()).Info(e.Error()) - } - q.Set("error_description", e.Message) - case *OAuthError: - q.Set("error", e.Err) - q.Set("error_description", e.Description) - log.WithError(e.Cause()).Info(e.Error()) - case ErrorCause: - return getErrorQueryString(e.Cause(), errorID, log) - default: - q.Set("error", "server_error") - q.Set("error_description", err.Error()) - } - return &q -} - -func (a *API) getExternalRedirectURL(r *http.Request) string { - ctx := r.Context() - config := a.getConfig(ctx) - if config.External.RedirectURL != "" { - return config.External.RedirectURL - } - if er := getExternalReferrer(ctx); er != "" { - return er - } - return config.SiteURL -} - -func (a *API) createNewIdentity(conn *storage.Connection, user *models.User, providerType string, identityData map[string]interface{}) (*models.Identity, error) { - identity, err := models.NewIdentity(user, providerType, identityData) - if err != nil { - return nil, err - } - - err = conn.Transaction(func(tx *storage.Connection) error { - if terr := tx.Create(identity); terr != nil { - return internalServerError("Error creating identity").WithInternalError(terr) - } - return nil - }) - - if err != nil { - return nil, err - } - - return identity, nil -} - -// getUserByVerifiedEmail checks if one of the verified emails already belongs to a user -func (a *API) getUserByVerifiedEmail(tx *storage.Connection, config *conf.Configuration, emails []provider.Email, instanceID uuid.UUID, aud string) (*models.User, provider.Email, error) { - var user *models.User - var emailData provider.Email - var err error - - for _, e := range emails { - if e.Verified || config.Mailer.Autoconfirm { - user, err = models.FindUserByEmailAndAudience(tx, instanceID, e.Email, aud) - if err != nil && !models.IsNotFoundError(err) { - return user, emailData, err - } - if user != nil { - emailData = e - break - } - } - } - return user, emailData, err -} diff --git a/api/external_azure_test.go b/api/external_azure_test.go deleted file mode 100644 index 0626e9a71..000000000 --- a/api/external_azure_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package api - -import ( - "fmt" - "net/http" - "net/http/httptest" - "net/url" - - jwt "github.com/golang-jwt/jwt" -) - -const ( - azureUser string = `{"name":"Azure Test","email":"azure@example.com","sub":"azuretestid"}` - azureUserNoEmail string = `{"name":"Azure Test","sub":"azuretestid"}` -) - -func (ts *ExternalTestSuite) TestSignupExternalAzure() { - req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=azure", nil) - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - ts.Require().Equal(http.StatusFound, w.Code) - u, err := url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - q := u.Query() - ts.Equal(ts.Config.External.Azure.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Azure.ClientID, q.Get("client_id")) - ts.Equal("code", q.Get("response_type")) - ts.Equal("openid", q.Get("scope")) - - claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} - _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { - return []byte(ts.Config.JWT.Secret), nil - }) - ts.Require().NoError(err) - - ts.Equal("azure", claims.Provider) - ts.Equal(ts.Config.SiteURL, claims.SiteURL) -} - -func AzureTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/common/oauth2/v2.0/token": - *tokenCount++ - ts.Equal(code, r.FormValue("code")) - ts.Equal("authorization_code", r.FormValue("grant_type")) - ts.Equal(ts.Config.External.Azure.RedirectURI, r.FormValue("redirect_uri")) - - w.Header().Add("Content-Type", "application/json") - fmt.Fprint(w, `{"access_token":"azure_token","expires_in":100000}`) - case "/oidc/userinfo": - *userCount++ - w.Header().Add("Content-Type", "application/json") - fmt.Fprint(w, user) - default: - w.WriteHeader(500) - ts.Fail("unknown azure oauth call %s", r.URL.Path) - } - })) - - ts.Config.External.Azure.URL = server.URL - - return server -} - -func (ts *ExternalTestSuite) TestSignupExternalAzure_AuthorizationCode() { - ts.Config.DisableSignup = false - ts.createUser("azuretestid", "azure@example.com", "Azure Test", "", "") - tokenCount, userCount := 0, 0 - code := "authcode" - server := AzureTestSignupSetup(ts, &tokenCount, &userCount, code, azureUser) - defer server.Close() - - u := performAuthorization(ts, "azure", code, "") - - assertAuthorizationSuccess(ts, u, tokenCount, userCount, "azure@example.com", "Azure Test", "azuretestid", "") -} - -func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupErrorWhenNoUser() { - ts.Config.DisableSignup = true - tokenCount, userCount := 0, 0 - code := "authcode" - server := AzureTestSignupSetup(ts, &tokenCount, &userCount, code, azureUser) - defer server.Close() - - u := performAuthorization(ts, "azure", code, "") - - assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "azure@example.com") -} - -func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupErrorWhenNoEmail() { - ts.Config.DisableSignup = true - tokenCount, userCount := 0, 0 - code := "authcode" - server := AzureTestSignupSetup(ts, &tokenCount, &userCount, code, azureUserNoEmail) - defer server.Close() - - u := performAuthorization(ts, "azure", code, "") - - assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "azure@example.com") - -} - -func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupSuccessWithPrimaryEmail() { - ts.Config.DisableSignup = true - - ts.createUser("azuretestid", "azure@example.com", "Azure Test", "", "") - - tokenCount, userCount := 0, 0 - code := "authcode" - server := AzureTestSignupSetup(ts, &tokenCount, &userCount, code, azureUser) - defer server.Close() - - u := performAuthorization(ts, "azure", code, "") - - assertAuthorizationSuccess(ts, u, tokenCount, userCount, "azure@example.com", "Azure Test", "azuretestid", "") -} - -func (ts *ExternalTestSuite) TestInviteTokenExternalAzureSuccessWhenMatchingToken() { - // name and avatar should be populated from Azure API - ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") - - tokenCount, userCount := 0, 0 - code := "authcode" - server := AzureTestSignupSetup(ts, &tokenCount, &userCount, code, azureUser) - defer server.Close() - - u := performAuthorization(ts, "azure", code, "invite_token") - - assertAuthorizationSuccess(ts, u, tokenCount, userCount, "azure@example.com", "Azure Test", "azuretestid", "") -} - -func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenNoMatchingToken() { - tokenCount, userCount := 0, 0 - code := "authcode" - azureUser := `{"name":"Azure Test","avatar":{"href":"http://example.com/avatar"}}` - server := AzureTestSignupSetup(ts, &tokenCount, &userCount, code, azureUser) - defer server.Close() - - w := performAuthorizationRequest(ts, "azure", "invite_token") - ts.Require().Equal(http.StatusNotFound, w.Code) -} - -func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenWrongToken() { - ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") - - tokenCount, userCount := 0, 0 - code := "authcode" - azureUser := `{"name":"Azure Test","avatar":{"href":"http://example.com/avatar"}}` - server := AzureTestSignupSetup(ts, &tokenCount, &userCount, code, azureUser) - defer server.Close() - - w := performAuthorizationRequest(ts, "azure", "wrong_token") - ts.Require().Equal(http.StatusNotFound, w.Code) -} - -func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenEmailDoesntMatch() { - ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") - - tokenCount, userCount := 0, 0 - code := "authcode" - azureUser := `{"name":"Azure Test", "email":"other@example.com", "avatar":{"href":"http://example.com/avatar"}}` - server := AzureTestSignupSetup(ts, &tokenCount, &userCount, code, azureUser) - defer server.Close() - - u := performAuthorization(ts, "azure", code, "invite_token") - - assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") -} diff --git a/api/external_saml.go b/api/external_saml.go deleted file mode 100644 index fa74d1cd6..000000000 --- a/api/external_saml.go +++ /dev/null @@ -1,76 +0,0 @@ -package api - -import ( - "context" - "net/http" - - "github.com/netlify/gotrue/api/provider" -) - -func (a *API) loadSAMLState(w http.ResponseWriter, r *http.Request) (context.Context, error) { - state := r.FormValue("RelayState") - if state == "" { - return nil, badRequestError("SAML RelayState is missing") - } - - ctx := r.Context() - - return a.loadExternalState(ctx, state) -} - -func (a *API) samlCallback(ctx context.Context, r *http.Request) (*provider.UserProvidedData, error) { - config := a.getConfig(ctx) - - samlProvider, err := provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx)) - if err != nil { - return nil, badRequestError("Could not initialize SAML provider: %+v", err).WithInternalError(err) - } - - samlResponse := r.FormValue("SAMLResponse") - if samlResponse == "" { - return nil, badRequestError("SAML Response is missing") - } - - assertionInfo, err := samlProvider.ServiceProvider.RetrieveAssertionInfo(samlResponse) - if err != nil { - return nil, internalServerError("Parsing SAML assertion failed: %+v", err).WithInternalError(err) - } - - if assertionInfo.WarningInfo.InvalidTime { - return nil, forbiddenError("SAML response has invalid time") - } - - if assertionInfo.WarningInfo.NotInAudience { - return nil, forbiddenError("SAML response is not in audience") - } - - if assertionInfo == nil { - return nil, internalServerError("SAML Assertion is missing") - } - userData := &provider.UserProvidedData{ - Emails: []provider.Email{{ - Email: assertionInfo.NameID, - Verified: true, - }}, - Metadata: &provider.Claims{ - Subject: assertionInfo.NameID, - }, - } - return userData, nil -} - -// SAMLMetadata returns metadata information about the SAML provider -func (a *API) SAMLMetadata(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := getConfig(ctx) - - samlProvider, err := provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx)) - if err != nil { - return internalServerError("Could not create SAML Provider: %+v", err).WithInternalError(err) - } - - metadata, err := samlProvider.SPMetadata() - w.Header().Set("Content-Type", "application/xml") - w.Write(metadata) - return nil -} diff --git a/api/external_saml_test.go b/api/external_saml_test.go deleted file mode 100644 index 9ed225a50..000000000 --- a/api/external_saml_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package api - -import ( - "crypto/x509" - "encoding/base64" - "encoding/pem" - "encoding/xml" - "io" - "net/http" - "net/http/httptest" - "net/url" - "path/filepath" - "strings" - "testing" - "text/template" - "time" - - "github.com/beevik/etree" - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/russellhaering/gosaml2/types" - dsig "github.com/russellhaering/goxmldsig" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type ExternalSamlTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - instanceID uuid.UUID -} - -func TestExternalSaml(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &ExternalSamlTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *ExternalSamlTestSuite) SetupTest() { - models.TruncateAll(ts.API.db) -} - -func (ts *ExternalSamlTestSuite) docFromTemplate(path string, data interface{}) *etree.Document { - doc := etree.NewDocument() - - templ, err := template.ParseFiles(path) - ts.Require().NoError(err) - read, write := io.Pipe() - go func() { - defer write.Close() - err := templ.Execute(write, data) - ts.Require().NoError(err) - }() - _, err = doc.ReadFrom(read) - ts.Require().NoError(err) - - return doc -} - -func (ts *ExternalSamlTestSuite) setupSamlExampleResponse(keyStore dsig.X509KeyStore) string { - path := filepath.Join("testdata", "saml-response.xml") - type ResponseParams struct { - Now string - NotBefore string - NotAfter string - } - now := time.Now() - doc := ts.docFromTemplate(path, ResponseParams{ - Now: now.Format(time.RFC3339), - NotBefore: now.Add(-5 * time.Minute).Format(time.RFC3339), - NotAfter: now.Add(5 * time.Minute).Format(time.RFC3339), - }) - - // sign - resp := doc.SelectElement("Response") - ctx := dsig.NewDefaultSigningContext(keyStore) - sig, err := ctx.ConstructSignature(resp, true) - ts.Require().NoError(err, "Response signature failed") - respWithSig := resp.Copy() - var children []etree.Token - children = append(children, respWithSig.Child[0]) // issuer is always first - children = append(children, sig) // next is the signature - children = append(children, respWithSig.Child[1:]...) // then all other children - respWithSig.Child = children - doc.SetRoot(respWithSig) - - docRaw, err := doc.WriteToBytes() - ts.Require().NoError(err) - - return base64.StdEncoding.EncodeToString(docRaw) -} - -func (ts *ExternalSamlTestSuite) setupSamlExampleState() string { - req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=saml", nil) - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - ts.Require().Equal(http.StatusFound, w.Code) - - u, err := url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - - urlBase, _ := url.Parse(u.String()) - urlBase.RawQuery = "" - ts.Equal(urlBase.String(), "https://idp/saml2test/redirect") - - q := u.Query() - state := q.Get("RelayState") - ts.Require().NotEmpty(state) - return state -} - -func (ts *ExternalSamlTestSuite) setupSamlMetadata() (*httptest.Server, dsig.X509KeyStore) { - idpKeyStore := dsig.RandomKeyStoreForTest() - _, idpCert, _ := idpKeyStore.GetKeyPair() - - path := filepath.Join("testdata", "saml-idp-metadata.xml") - type MetadataParams struct { - Cert string - } - doc := ts.docFromTemplate(path, MetadataParams{Cert: base64.StdEncoding.EncodeToString(idpCert)}) - metadata, err := doc.WriteToString() - ts.Require().NoError(err) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(200) - io.WriteString(w, metadata) - })) - return server, idpKeyStore -} - -func (ts *ExternalSamlTestSuite) setupSamlSPCert() (string, string) { - spKeyStore := dsig.RandomKeyStoreForTest() - key, cert, _ := spKeyStore.GetKeyPair() - keyBytes := pem.EncodeToMemory(&pem.Block{ - Type: "PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }) - certBytes := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cert, - }) - return string(keyBytes), string(certBytes) -} - -func (ts *ExternalSamlTestSuite) TestSignupExternalSaml_Callback() { - ts.SetupTest() - server, idpKeyStore := ts.setupSamlMetadata() - defer server.Close() - ts.Config.External.Saml.MetadataURL = server.URL - - key, cert := ts.setupSamlSPCert() - ts.Config.External.Saml.SigningKey = key - ts.Config.External.Saml.SigningCert = cert - - form := url.Values{} - form.Add("RelayState", ts.setupSamlExampleState()) - form.Add("SAMLResponse", ts.setupSamlExampleResponse(idpKeyStore)) - req := httptest.NewRequest(http.MethodPost, "http://localhost/saml/acs", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - - ts.Require().Equal(http.StatusFound, w.Code) - - u, err := url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - - v, err := url.ParseQuery(u.Fragment) - ts.Require().NoError(err) - ts.Empty(v.Get("error_description")) - ts.Empty(v.Get("error")) - - ts.NotEmpty(v.Get("access_token")) - ts.NotEmpty(v.Get("refresh_token")) - ts.NotEmpty(v.Get("expires_in")) - ts.Equal("bearer", v.Get("token_type")) - - // ensure user has been created - _, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "saml@example.com", ts.Config.JWT.Aud) - ts.Require().NoError(err) -} - -func (ts *ExternalSamlTestSuite) TestMetadata() { - server, _ := ts.setupSamlMetadata() - defer server.Close() - ts.Config.External.Saml.MetadataURL = server.URL - - key, cert := ts.setupSamlSPCert() - ts.Config.External.Saml.SigningKey = key - ts.Config.External.Saml.SigningCert = cert - - req := httptest.NewRequest(http.MethodGet, "http://localhost/saml/metadata", nil) - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - - ts.Require().Equal(http.StatusOK, w.Code) - - md := &types.EntityDescriptor{} - err := xml.NewDecoder(w.Body).Decode(md) - ts.Require().NoError(err) - - ts.Equal("http://localhost/saml", md.EntityID) - for _, acs := range md.SPSSODescriptor.AssertionConsumerServices { - if acs.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" { - ts.Equal("http://localhost/saml/acs", acs.Location) - break - } - } -} diff --git a/api/external_test.go b/api/external_test.go deleted file mode 100644 index 9aee2c30c..000000000 --- a/api/external_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package api - -import ( - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type ExternalTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - instanceID uuid.UUID -} - -func TestExternal(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &ExternalTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *ExternalTestSuite) SetupTest() { - ts.Config.DisableSignup = false - ts.Config.Mailer.Autoconfirm = false - - models.TruncateAll(ts.API.db) -} - -func (ts *ExternalTestSuite) createUser(providerId string, email string, name string, avatar string, confirmationToken string) (*models.User, error) { - // Cleanup existing user, if they already exist - if u, _ := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, email, ts.Config.JWT.Aud); u != nil { - require.NoError(ts.T(), ts.API.db.Destroy(u), "Error deleting user") - } - - u, err := models.NewUser(ts.instanceID, email, "test", ts.Config.JWT.Aud, map[string]interface{}{"provider_id": providerId, "full_name": name, "avatar_url": avatar}) - - if confirmationToken != "" { - u.ConfirmationToken = confirmationToken - } - ts.Require().NoError(err, "Error making new user") - ts.Require().NoError(ts.API.db.Create(u), "Error creating user") - - return u, err -} - -func performAuthorizationRequest(ts *ExternalTestSuite, provider string, inviteToken string) *httptest.ResponseRecorder { - authorizeURL := "http://localhost/authorize?provider=" + provider - if inviteToken != "" { - authorizeURL = authorizeURL + "&invite_token=" + inviteToken - } - - req := httptest.NewRequest(http.MethodGet, authorizeURL, nil) - req.Header.Set("Referer", "https://example.netlify.com/admin") - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - - return w -} - -func performAuthorization(ts *ExternalTestSuite, provider string, code string, inviteToken string) *url.URL { - w := performAuthorizationRequest(ts, provider, inviteToken) - ts.Require().Equal(http.StatusFound, w.Code) - u, err := url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - q := u.Query() - state := q.Get("state") - - // auth server callback - testURL, err := url.Parse("http://localhost/callback") - ts.Require().NoError(err) - v := testURL.Query() - v.Set("code", code) - v.Set("state", state) - testURL.RawQuery = v.Encode() - req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - ts.Require().Equal(http.StatusFound, w.Code) - u, err = url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - ts.Require().Equal("/admin", u.Path) - - return u -} - -func assertAuthorizationSuccess(ts *ExternalTestSuite, u *url.URL, tokenCount int, userCount int, email string, name string, providerId string, avatar string) { - // ensure redirect has #access_token=... - v, err := url.ParseQuery(u.RawQuery) - ts.Require().NoError(err) - ts.Require().Empty(v.Get("error_description")) - ts.Require().Empty(v.Get("error")) - - v, err = url.ParseQuery(u.Fragment) - ts.Require().NoError(err) - ts.NotEmpty(v.Get("access_token")) - ts.NotEmpty(v.Get("refresh_token")) - ts.NotEmpty(v.Get("expires_in")) - ts.Equal("bearer", v.Get("token_type")) - - ts.Equal(1, tokenCount) - ts.Equal(1, userCount) - - // ensure user has been created with metadata - user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, email, ts.Config.JWT.Aud) - ts.Require().NoError(err) - ts.Equal(providerId, user.UserMetaData["provider_id"]) - ts.Equal(name, user.UserMetaData["full_name"]) - ts.Equal(avatar, user.UserMetaData["avatar_url"]) -} - -func assertAuthorizationFailure(ts *ExternalTestSuite, u *url.URL, errorDescription string, errorType string, email string) { - // ensure new sign ups error - v, err := url.ParseQuery(u.RawQuery) - ts.Require().NoError(err) - ts.Require().Equal(errorDescription, v.Get("error_description")) - ts.Require().Equal(errorType, v.Get("error")) - - v, err = url.ParseQuery(u.Fragment) - ts.Require().NoError(err) - ts.Empty(v.Get("access_token")) - ts.Empty(v.Get("refresh_token")) - ts.Empty(v.Get("expires_in")) - ts.Empty(v.Get("token_type")) - - // ensure user is nil - user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, email, ts.Config.JWT.Aud) - ts.Require().Error(err, "User not found") - ts.Require().Nil(user) -} - -// TestSignupExternalUnsupported tests API /authorize for an unsupported external provider -func (ts *ExternalTestSuite) TestSignupExternalUnsupported() { - req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=external", nil) - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - ts.Equal(w.Code, http.StatusBadRequest) -} diff --git a/api/helpers.go b/api/helpers.go deleted file mode 100644 index c8cfc2f35..000000000 --- a/api/helpers.go +++ /dev/null @@ -1,278 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "fmt" - "net" - "net/http" - "net/http/httptrace" - "net/url" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" -) - -func addRequestID(globalConfig *conf.GlobalConfiguration) middlewareHandler { - return func(w http.ResponseWriter, r *http.Request) (context.Context, error) { - id := "" - if globalConfig.API.RequestIDHeader != "" { - id = r.Header.Get(globalConfig.API.RequestIDHeader) - } - if id == "" { - uid, err := uuid.NewV4() - if err != nil { - return nil, err - } - id = uid.String() - } - - ctx := r.Context() - ctx = withRequestID(ctx, id) - return ctx, nil - } -} - -func sendJSON(w http.ResponseWriter, status int, obj interface{}) error { - w.Header().Set("Content-Type", "application/json") - b, err := json.Marshal(obj) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("Error encoding json response: %v", obj)) - } - w.WriteHeader(status) - _, err = w.Write(b) - return err -} - -func getUserFromClaims(ctx context.Context, conn *storage.Connection) (*models.User, error) { - claims := getClaims(ctx) - if claims == nil { - return nil, errors.New("Invalid token") - } - - if claims.Subject == "" { - return nil, errors.New("Invalid claim: id") - } - - // System User - instanceID := getInstanceID(ctx) - - if claims.Subject == models.SystemUserUUID.String() || claims.Subject == models.SystemUserID { - return models.NewSystemUser(instanceID, claims.Audience), nil - } - userID, err := uuid.FromString(claims.Subject) - if err != nil { - return nil, errors.New("Invalid user ID") - } - return models.FindUserByInstanceIDAndID(conn, instanceID, userID) -} - -func (a *API) isAdmin(ctx context.Context, u *models.User, aud string) bool { - config := a.getConfig(ctx) - if aud == "" { - aud = config.JWT.Aud - } - return u.IsSuperAdmin || (aud == u.Aud && u.HasRole(config.JWT.AdminGroupName)) -} - -func (a *API) requestAud(ctx context.Context, r *http.Request) string { - config := a.getConfig(ctx) - // First check for an audience in the header - if aud := r.Header.Get(audHeaderName); aud != "" { - return aud - } - - // Then check the token - claims := getClaims(ctx) - if claims != nil && claims.Audience != "" { - return claims.Audience - } - - // Finally, return the default of none of the above methods are successful - return config.JWT.Aud -} - -// tries extract redirect url from header or from query params -func getRedirectTo(r *http.Request) (reqref string) { - reqref = r.Header.Get("redirect_to") - if reqref != "" { - return - } - - if err := r.ParseForm(); err == nil { - reqref = r.Form.Get("redirect_to") - } - - return -} - -func isRedirectURLValid(config *conf.Configuration, redirectURL string) bool { - if redirectURL == "" { - return false - } - - base, berr := url.Parse(config.SiteURL) - refurl, rerr := url.Parse(redirectURL) - - // As long as the referrer came from the site, we will redirect back there - if berr == nil && rerr == nil && base.Hostname() == refurl.Hostname() { - return true - } - - // For case when user came from mobile app or other permitted resource - redirect back - for _, uri := range config.URIAllowList { - if redirectURL == uri { - return true - } - } - - return false -} - -func (a *API) getReferrer(r *http.Request) string { - ctx := r.Context() - config := a.getConfig(ctx) - - // try get redirect url from query or post data first - reqref := getRedirectTo(r) - if isRedirectURLValid(config, reqref) { - return reqref - } - - // instead try referrer header value - reqref = r.Referer() - if isRedirectURLValid(config, reqref) { - return reqref - } - - return config.SiteURL -} - -// getRedirectURLOrReferrer ensures any redirect URL is from a safe origin -func (a *API) getRedirectURLOrReferrer(r *http.Request, reqref string) string { - ctx := r.Context() - config := a.getConfig(ctx) - - // if redirect url fails - try fill by extra variant - if isRedirectURLValid(config, reqref) { - return reqref - } - - return a.getReferrer(r) -} - -var privateIPBlocks []*net.IPNet - -func init() { - for _, cidr := range []string{ - "127.0.0.0/8", // IPv4 loopback - "10.0.0.0/8", // RFC1918 - "100.64.0.0/10", // RFC6598 - "172.16.0.0/12", // RFC1918 - "192.0.0.0/24", // RFC6890 - "192.168.0.0/16", // RFC1918 - "169.254.0.0/16", // RFC3927 - "::1/128", // IPv6 loopback - "fe80::/10", // IPv6 link-local - "fc00::/7", // IPv6 unique local addr - } { - _, block, _ := net.ParseCIDR(cidr) - privateIPBlocks = append(privateIPBlocks, block) - } -} - -func isPrivateIP(ip net.IP) bool { - for _, block := range privateIPBlocks { - if block.Contains(ip) { - return true - } - } - return false -} - -func removeLocalhostFromPrivateIPBlock() *net.IPNet { - _, localhost, _ := net.ParseCIDR("127.0.0.0/8") - - var localhostIndex int - for i := 0; i < len(privateIPBlocks); i++ { - if privateIPBlocks[i] == localhost { - localhostIndex = i - } - } - privateIPBlocks = append(privateIPBlocks[:localhostIndex], privateIPBlocks[localhostIndex+1:]...) - - return localhost -} - -func unshiftPrivateIPBlock(address *net.IPNet) { - privateIPBlocks = append([]*net.IPNet{address}, privateIPBlocks...) -} - -type noLocalTransport struct { - inner http.RoundTripper - errlog logrus.FieldLogger -} - -func (no noLocalTransport) RoundTrip(req *http.Request) (*http.Response, error) { - ctx, cancel := context.WithCancel(req.Context()) - - ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ - ConnectStart: func(network, addr string) { - fmt.Printf("Checking network %v\n", addr) - host, _, err := net.SplitHostPort(addr) - if err != nil { - cancel() - fmt.Printf("Canceleing dur to error in addr parsing %v", err) - return - } - ip := net.ParseIP(host) - if ip == nil { - cancel() - fmt.Printf("Canceleing dur to error in ip parsing %v", host) - return - } - - if isPrivateIP(ip) { - cancel() - fmt.Println("Canceleing dur to private ip range") - return - } - - }, - }) - - req = req.WithContext(ctx) - return no.inner.RoundTrip(req) -} - -func SafeRoundtripper(trans http.RoundTripper, log logrus.FieldLogger) http.RoundTripper { - if trans == nil { - trans = http.DefaultTransport - } - - ret := &noLocalTransport{ - inner: trans, - errlog: log.WithField("transport", "local_blocker"), - } - - return ret -} - -func SafeHTTPClient(client *http.Client, log logrus.FieldLogger) *http.Client { - client.Transport = SafeRoundtripper(client.Transport, log) - - return client -} - -func isStringInSlice(checkValue string, list []string) bool { - for _, val := range list { - if val == checkValue { - return true - } - } - return false -} diff --git a/api/hook_test.go b/api/hook_test.go deleted file mode 100644 index f60a0218c..000000000 --- a/api/hook_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "io/ioutil" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSignupHookSendInstanceID(t *testing.T) { - globalConfig, err := conf.LoadGlobal(apiTestConfig) - require.NoError(t, err) - - conn, err := test.SetupDBConnection(globalConfig) - require.NoError(t, err) - - iid := uuid.Must(uuid.NewV4()) - user, err := models.NewUser(iid, "test@truth.com", "thisisapassword", "", nil) - require.NoError(t, err) - - var callCount int - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - defer squash(r.Body.Close) - raw, err := ioutil.ReadAll(r.Body) - require.NoError(t, err) - - data := map[string]interface{}{} - require.NoError(t, json.Unmarshal(raw, &data)) - - assert.Len(t, data, 3) - assert.Equal(t, iid.String(), data["instance_id"]) - w.WriteHeader(http.StatusOK) - })) - defer svr.Close() - - // Allowing connection to localhost for the tests only - localhost := removeLocalhostFromPrivateIPBlock() - defer unshiftPrivateIPBlock(localhost) - - config := &conf.Configuration{ - Webhook: conf.WebhookConfig{ - URL: svr.URL, - Events: []string{SignupEvent}, - }, - } - - require.NoError(t, triggerEventHooks(context.Background(), conn, SignupEvent, user, iid, config)) - - assert.Equal(t, 1, callCount) -} - -func TestSignupHookFromClaims(t *testing.T) { - globalConfig, err := conf.LoadGlobal(apiTestConfig) - require.NoError(t, err) - - conn, err := test.SetupDBConnection(globalConfig) - require.NoError(t, err) - - iid := uuid.Must(uuid.NewV4()) - user, err := models.NewUser(iid, "test@truth.com", "thisisapassword", "", nil) - require.NoError(t, err) - - var callCount int - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - defer squash(r.Body.Close) - raw, err := ioutil.ReadAll(r.Body) - require.NoError(t, err) - - data := map[string]interface{}{} - require.NoError(t, json.Unmarshal(raw, &data)) - - assert.Len(t, data, 3) - assert.Equal(t, iid.String(), data["instance_id"]) - w.WriteHeader(http.StatusOK) - })) - defer svr.Close() - - // Allowing connection to localhost for the tests only - localhost := removeLocalhostFromPrivateIPBlock() - defer unshiftPrivateIPBlock(localhost) - - config := &conf.Configuration{ - Webhook: conf.WebhookConfig{ - Events: []string{"signup"}, - }, - } - - ctx := context.Background() - ctx = withFunctionHooks(ctx, map[string][]string{ - "signup": []string{svr.URL}, - }) - - require.NoError(t, triggerEventHooks(ctx, conn, SignupEvent, user, iid, config)) - - assert.Equal(t, 1, callCount) -} - -func TestHookRetry(t *testing.T) { - var callCount int - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - assert.EqualValues(t, 0, r.ContentLength) - if callCount == 3 { - w.WriteHeader(http.StatusOK) - } else { - w.WriteHeader(http.StatusBadRequest) - } - })) - defer svr.Close() - // Allowing connection to localhost for the tests only - localhost := removeLocalhostFromPrivateIPBlock() - defer unshiftPrivateIPBlock(localhost) - - config := &conf.WebhookConfig{ - URL: svr.URL, - Retries: 3, - } - w := Webhook{ - WebhookConfig: config, - } - b, err := w.trigger() - defer func() { - if b != nil { - b.Close() - } - }() - require.NoError(t, err) - - assert.Equal(t, 3, callCount) -} - -func TestHookTimeout(t *testing.T) { - var callCount int - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - <-time.After(2 * time.Second) - })) - defer svr.Close() - - // Allowing connection to localhost for the tests only - localhost := removeLocalhostFromPrivateIPBlock() - defer unshiftPrivateIPBlock(localhost) - - config := &conf.WebhookConfig{ - URL: svr.URL, - Retries: 3, - TimeoutSec: 1, - } - w := Webhook{ - WebhookConfig: config, - } - _, err := w.trigger() - require.Error(t, err) - herr, ok := err.(*HTTPError) - require.True(t, ok) - assert.Equal(t, http.StatusGatewayTimeout, herr.Code) - - assert.Equal(t, 3, callCount) -} - -func TestHookNoServer(t *testing.T) { - config := &conf.WebhookConfig{ - URL: "http://somewhere.something.com", - Retries: 1, - TimeoutSec: 1, - } - w := Webhook{ - WebhookConfig: config, - } - _, err := w.trigger() - require.Error(t, err) - herr, ok := err.(*HTTPError) - require.True(t, ok) - assert.Equal(t, http.StatusBadGateway, herr.Code) -} - -func squash(f func() error) { _ = f } diff --git a/api/hooks.go b/api/hooks.go deleted file mode 100644 index 28f8ccc28..000000000 --- a/api/hooks.go +++ /dev/null @@ -1,292 +0,0 @@ -package api - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "io" - "net" - "net/http" - "net/http/httptrace" - "net/url" - "time" - - "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" -) - -type HookEvent string - -const ( - headerHookSignature = "x-webhook-signature" - defaultHookRetries = 3 - gotrueIssuer = "gotrue" - ValidateEvent = "validate" - SignupEvent = "signup" - EmailChangeEvent = "email_change" - LoginEvent = "login" -) - -var defaultTimeout = time.Second * 5 - -type webhookClaims struct { - jwt.StandardClaims - SHA256 string `json:"sha256"` -} - -type Webhook struct { - *conf.WebhookConfig - - instanceID uuid.UUID - jwtSecret string - claims jwt.Claims - payload []byte - headers map[string]string -} - -type WebhookResponse struct { - AppMetaData map[string]interface{} `json:"app_metadata,omitempty"` - UserMetaData map[string]interface{} `json:"user_metadata,omitempty"` -} - -func (w *Webhook) trigger() (io.ReadCloser, error) { - timeout := defaultTimeout - if w.TimeoutSec > 0 { - timeout = time.Duration(w.TimeoutSec) * time.Second - } - - if w.Retries == 0 { - w.Retries = defaultHookRetries - } - - hooklog := logrus.WithFields(logrus.Fields{ - "component": "webhook", - "url": w.URL, - "signed": w.jwtSecret != "", - "instance_id": w.instanceID, - }) - client := http.Client{ - Timeout: timeout, - } - client.Transport = SafeRoundtripper(client.Transport, hooklog) - - for i := 0; i < w.Retries; i++ { - hooklog = hooklog.WithField("attempt", i+1) - hooklog.Info("Starting to perform signup hook request") - - req, err := http.NewRequest(http.MethodPost, w.URL, bytes.NewBuffer(w.payload)) - if err != nil { - return nil, internalServerError("Failed to make request object").WithInternalError(err) - } - req.Header.Set("Content-Type", "application/json") - watcher, req := watchForConnection(req) - - if w.jwtSecret != "" { - header, jwtErr := w.generateSignature() - if jwtErr != nil { - return nil, jwtErr - } - req.Header.Set(headerHookSignature, header) - } - - start := time.Now() - rsp, err := client.Do(req) - if err != nil { - if terr, ok := err.(net.Error); ok && terr.Timeout() { - // timed out - try again? - if i == w.Retries-1 { - closeBody(rsp) - return nil, httpError(http.StatusGatewayTimeout, "Failed to perform webhook in time frame (%v seconds)", timeout.Seconds()) - } - hooklog.Info("Request timed out") - continue - } else if watcher.gotConn { - closeBody(rsp) - return nil, internalServerError("Failed to trigger webhook to %s", w.URL).WithInternalError(err) - } else { - closeBody(rsp) - return nil, httpError(http.StatusBadGateway, "Failed to connect to %s", w.URL) - } - } - dur := time.Since(start) - rspLog := hooklog.WithFields(logrus.Fields{ - "status_code": rsp.StatusCode, - "dur": dur.Nanoseconds(), - }) - switch rsp.StatusCode { - case http.StatusOK, http.StatusNoContent, http.StatusAccepted: - rspLog.Infof("Finished processing webhook in %s", dur) - var body io.ReadCloser - if rsp.ContentLength > 0 { - body = rsp.Body - } - return body, nil - default: - rspLog.Infof("Bad response for webhook %d in %s", rsp.StatusCode, dur) - } - } - - hooklog.Infof("Failed to process webhook for %s after %d attempts", w.URL, w.Retries) - return nil, unprocessableEntityError("Failed to handle signup webhook") -} - -func (w *Webhook) generateSignature() (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, w.claims) - tokenString, err := token.SignedString([]byte(w.jwtSecret)) - if err != nil { - return "", internalServerError("Failed build signing string").WithInternalError(err) - } - return tokenString, nil -} - -func closeBody(rsp *http.Response) { - if rsp != nil && rsp.Body != nil { - rsp.Body.Close() - } -} - -func triggerEventHooks(ctx context.Context, conn *storage.Connection, event HookEvent, user *models.User, instanceID uuid.UUID, config *conf.Configuration) error { - if config.Webhook.URL != "" { - hookURL, err := url.Parse(config.Webhook.URL) - if err != nil { - return errors.Wrapf(err, "Failed to parse Webhook URL") - } - if !config.Webhook.HasEvent(string(event)) { - return nil - } - return triggerHook(ctx, hookURL, config.Webhook.Secret, conn, event, user, instanceID, config) - } - - fun := getFunctionHooks(ctx) - if fun == nil { - return nil - } - - for _, eventHookURL := range fun[string(event)] { - hookURL, err := url.Parse(eventHookURL) - if err != nil { - return errors.Wrapf(err, "Failed to parse Event Function Hook URL") - } - err = triggerHook(ctx, hookURL, config.JWT.Secret, conn, event, user, instanceID, config) - if err != nil { - return err - } - } - return nil -} - -func triggerHook(ctx context.Context, hookURL *url.URL, secret string, conn *storage.Connection, event HookEvent, user *models.User, instanceID uuid.UUID, config *conf.Configuration) error { - if !hookURL.IsAbs() { - siteURL, err := url.Parse(config.SiteURL) - if err != nil { - return errors.Wrapf(err, "Failed to parse Site URL") - } - hookURL.Scheme = siteURL.Scheme - hookURL.Host = siteURL.Host - hookURL.User = siteURL.User - } - - payload := struct { - Event HookEvent `json:"event"` - InstanceID uuid.UUID `json:"instance_id,omitempty"` - User *models.User `json:"user"` - }{ - Event: event, - InstanceID: instanceID, - User: user, - } - data, err := json.Marshal(&payload) - if err != nil { - return internalServerError("Failed to serialize the data for signup webhook").WithInternalError(err) - } - - sha, err := checksum(data) - if err != nil { - return internalServerError("Failed to checksum the data for signup webhook").WithInternalError(err) - } - - claims := webhookClaims{ - StandardClaims: jwt.StandardClaims{ - IssuedAt: time.Now().Unix(), - Subject: instanceID.String(), - Issuer: gotrueIssuer, - }, - SHA256: sha, - } - - w := Webhook{ - WebhookConfig: &config.Webhook, - jwtSecret: secret, - instanceID: instanceID, - claims: claims, - payload: data, - } - - w.URL = hookURL.String() - - body, err := w.trigger() - defer func() { - if body != nil { - body.Close() - } - }() - if err == nil && body != nil { - webhookRsp := &WebhookResponse{} - decoder := json.NewDecoder(body) - if err = decoder.Decode(webhookRsp); err != nil { - return internalServerError("Webhook returned malformed JSON: %v", err).WithInternalError(err) - } - return conn.Transaction(func(tx *storage.Connection) error { - if webhookRsp.UserMetaData != nil { - user.UserMetaData = nil - if terr := user.UpdateUserMetaData(tx, webhookRsp.UserMetaData); terr != nil { - return terr - } - } - if webhookRsp.AppMetaData != nil { - user.AppMetaData = nil - if terr := user.UpdateAppMetaData(tx, webhookRsp.AppMetaData); terr != nil { - return terr - } - } - return nil - }) - } - return err -} - -func watchForConnection(req *http.Request) (*connectionWatcher, *http.Request) { - w := new(connectionWatcher) - t := &httptrace.ClientTrace{ - GotConn: w.GotConn, - } - - req = req.WithContext(httptrace.WithClientTrace(req.Context(), t)) - return w, req -} - -func checksum(data []byte) (string, error) { - sha := sha256.New() - _, err := sha.Write(data) - if err != nil { - return "", err - } - - return hex.EncodeToString(sha.Sum(nil)), nil -} - -type connectionWatcher struct { - gotConn bool -} - -func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) { - c.gotConn = true -} diff --git a/api/instance.go b/api/instance.go deleted file mode 100644 index f15d48582..000000000 --- a/api/instance.go +++ /dev/null @@ -1,130 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "net/http" - - "github.com/go-chi/chi" - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/pkg/errors" -) - -func (a *API) loadInstance(w http.ResponseWriter, r *http.Request) (context.Context, error) { - instanceID, err := uuid.FromString(chi.URLParam(r, "instance_id")) - if err != nil { - return nil, badRequestError("Invalid instance ID") - } - logEntrySetField(r, "instance_id", instanceID) - - i, err := models.GetInstance(a.db, instanceID) - if err != nil { - if models.IsNotFoundError(err) { - return nil, notFoundError("Instance not found") - } - return nil, internalServerError("Database error loading instance").WithInternalError(err) - } - - return withInstance(r.Context(), i), nil -} - -func (a *API) GetAppManifest(w http.ResponseWriter, r *http.Request) error { - // TODO update to real manifest - return sendJSON(w, http.StatusOK, map[string]string{ - "version": a.version, - "name": "GoTrue", - "description": "GoTrue is a user registration and authentication API", - }) -} - -type InstanceRequestParams struct { - UUID uuid.UUID `json:"uuid"` - BaseConfig *conf.Configuration `json:"config"` -} - -type InstanceResponse struct { - models.Instance - Endpoint string `json:"endpoint"` - State string `json:"state"` -} - -func (a *API) CreateInstance(w http.ResponseWriter, r *http.Request) error { - params := InstanceRequestParams{} - if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { - return badRequestError("Error decoding params: %v", err) - } - - _, err := models.GetInstanceByUUID(a.db, params.UUID) - if err != nil { - if !models.IsNotFoundError(err) { - return internalServerError("Database error looking up instance").WithInternalError(err) - } - } else { - return badRequestError("An instance with that UUID already exists") - } - - id, err := uuid.NewV4() - if err != nil { - return errors.Wrap(err, "Error generating id") - } - - i := models.Instance{ - ID: id, - UUID: params.UUID, - BaseConfig: params.BaseConfig, - } - if err = a.db.Create(&i); err != nil { - return internalServerError("Database error creating instance").WithInternalError(err) - } - - // hide pass in response - if i.BaseConfig != nil { - i.BaseConfig.SMTP.Pass = "" - } - - resp := InstanceResponse{ - Instance: i, - Endpoint: a.config.API.Endpoint, - State: "active", - } - return sendJSON(w, http.StatusCreated, resp) -} - -func (a *API) GetInstance(w http.ResponseWriter, r *http.Request) error { - i := getInstance(r.Context()) - if i.BaseConfig != nil { - i.BaseConfig.SMTP.Pass = "" - } - return sendJSON(w, http.StatusOK, i) -} - -func (a *API) UpdateInstance(w http.ResponseWriter, r *http.Request) error { - i := getInstance(r.Context()) - - params := InstanceRequestParams{BaseConfig: i.BaseConfig} - if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { - return badRequestError("Error decoding params: %v", err) - } - - if err := i.UpdateConfig(a.db, params.BaseConfig); err != nil { - return internalServerError("Database error updating instance").WithInternalError(err) - } - - // Hide SMTP credential from response - if i.BaseConfig != nil { - i.BaseConfig.SMTP.Pass = "" - } - return sendJSON(w, http.StatusOK, i) -} - -func (a *API) DeleteInstance(w http.ResponseWriter, r *http.Request) error { - i := getInstance(r.Context()) - if err := models.DeleteInstance(a.db, i); err != nil { - return internalServerError("Database error deleting instance").WithInternalError(err) - } - - w.WriteHeader(http.StatusNoContent) - return nil -} diff --git a/api/instance_test.go b/api/instance_test.go deleted file mode 100644 index 17d720d2a..000000000 --- a/api/instance_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gofrs/uuid" - - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -var testUUID = uuid.Must(uuid.FromString("11111111-1111-1111-1111-111111111111")) - -const operatorToken = "operatorToken" - -type InstanceTestSuite struct { - suite.Suite - API *API -} - -func TestInstance(t *testing.T) { - api, _, err := setupAPIForMultiinstanceTest() - require.NoError(t, err) - - api.config.OperatorToken = operatorToken - - ts := &InstanceTestSuite{ - API: api, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *InstanceTestSuite) SetupTest() { - models.TruncateAll(ts.API.db) -} - -func (ts *InstanceTestSuite) TestCreate() { - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "uuid": testUUID, - "site_url": "https://example.netlify.com", - "config": map[string]interface{}{ - "jwt": map[string]interface{}{ - "secret": "testsecret", - }, - }, - })) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "/instances", &buffer) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+operatorToken) - - // Setup response recorder - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusCreated, w.Code) - resp := models.Instance{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&resp)) - assert.NotNil(ts.T(), resp.BaseConfig) - - i, err := models.GetInstanceByUUID(ts.API.db, testUUID) - require.NoError(ts.T(), err) - assert.NotNil(ts.T(), i.BaseConfig) -} - -func (ts *InstanceTestSuite) TestGet() { - instanceID := uuid.Must(uuid.NewV4()) - err := ts.API.db.Create(&models.Instance{ - ID: instanceID, - UUID: testUUID, - BaseConfig: &conf.Configuration{ - JWT: conf.JWTConfiguration{ - Secret: "testsecret", - }, - }, - }) - require.NoError(ts.T(), err) - - req := httptest.NewRequest(http.MethodGet, "/instances/"+instanceID.String(), nil) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+operatorToken) - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), w.Code, http.StatusOK) - resp := models.Instance{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&resp)) -} - -func (ts *InstanceTestSuite) TestUpdate() { - instanceID := uuid.Must(uuid.NewV4()) - err := ts.API.db.Create(&models.Instance{ - ID: instanceID, - UUID: testUUID, - BaseConfig: &conf.Configuration{ - JWT: conf.JWTConfiguration{ - Secret: "testsecret", - }, - }, - }) - require.NoError(ts.T(), err) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "config": &conf.Configuration{ - JWT: conf.JWTConfiguration{ - Secret: "testsecret", - }, - SiteURL: "https://test.mysite.com", - }, - })) - - req := httptest.NewRequest(http.MethodPut, "/instances/"+instanceID.String(), &buffer) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+operatorToken) - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), w.Code, http.StatusOK) - - i, err := models.GetInstanceByUUID(ts.API.db, testUUID) - require.NoError(ts.T(), err) - require.Equal(ts.T(), i.BaseConfig.JWT.Secret, "testsecret") - require.Equal(ts.T(), i.BaseConfig.SiteURL, "https://test.mysite.com") -} - -func (ts *InstanceTestSuite) TestUpdate_DisableEmail() { - instanceID := uuid.Must(uuid.NewV4()) - err := ts.API.db.Create(&models.Instance{ - ID: instanceID, - UUID: testUUID, - BaseConfig: &conf.Configuration{ - External: conf.ProviderConfiguration{ - Email: conf.EmailProviderConfiguration{ - Enabled: true, - }, - }, - }, - }) - require.NoError(ts.T(), err) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "config": &conf.Configuration{ - External: conf.ProviderConfiguration{ - Email: conf.EmailProviderConfiguration{ - Enabled: false, - }, - }, - }, - })) - - req := httptest.NewRequest(http.MethodPut, "/instances/"+instanceID.String(), &buffer) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+operatorToken) - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), w.Code, http.StatusOK) - - i, err := models.GetInstanceByUUID(ts.API.db, testUUID) - require.NoError(ts.T(), err) - require.False(ts.T(), i.BaseConfig.External.Email.Enabled) -} - -func (ts *InstanceTestSuite) TestUpdate_PreserveSMTPConfig() { - instanceID := uuid.Must(uuid.NewV4()) - err := ts.API.db.Create(&models.Instance{ - ID: instanceID, - UUID: testUUID, - BaseConfig: &conf.Configuration{ - SMTP: conf.SMTPConfiguration{ - Host: "foo.com", - User: "Admin", - Pass: "password123", - }, - }, - }) - require.NoError(ts.T(), err) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "config": &conf.Configuration{ - Mailer: conf.MailerConfiguration{ - Subjects: conf.EmailContentConfiguration{Invite: "foo"}, - Templates: conf.EmailContentConfiguration{Invite: "bar"}, - }, - }, - })) - - req := httptest.NewRequest(http.MethodPut, "/instances/"+instanceID.String(), &buffer) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+operatorToken) - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), w.Code, http.StatusOK) - - i, err := models.GetInstanceByUUID(ts.API.db, testUUID) - require.NoError(ts.T(), err) - require.Equal(ts.T(), "password123", i.BaseConfig.SMTP.Pass) -} - -func (ts *InstanceTestSuite) TestUpdate_ClearPassword() { - instanceID := uuid.Must(uuid.NewV4()) - err := ts.API.db.Create(&models.Instance{ - ID: instanceID, - UUID: testUUID, - BaseConfig: &conf.Configuration{ - SMTP: conf.SMTPConfiguration{ - Host: "foo.com", - User: "Admin", - Pass: "password123", - }, - }, - }) - require.NoError(ts.T(), err) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "config": map[string]interface{}{ - "smtp": map[string]interface{}{ - "pass": "", - }, - }, - })) - ts.T().Log(buffer.String()) - - req := httptest.NewRequest(http.MethodPut, "/instances/"+instanceID.String(), &buffer) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+operatorToken) - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), w.Code, http.StatusOK) - - i, err := models.GetInstanceByUUID(ts.API.db, testUUID) - require.NoError(ts.T(), err) - require.Equal(ts.T(), "", i.BaseConfig.SMTP.Pass) -} diff --git a/api/invite.go b/api/invite.go deleted file mode 100644 index 6c13861c8..000000000 --- a/api/invite.go +++ /dev/null @@ -1,77 +0,0 @@ -package api - -import ( - "encoding/json" - "net/http" - - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" -) - -// InviteParams are the parameters the Signup endpoint accepts -type InviteParams struct { - Email string `json:"email"` - Data map[string]interface{} `json:"data"` -} - -// Invite is the endpoint for inviting a new user -func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - instanceID := getInstanceID(ctx) - adminUser := getAdminUser(ctx) - params := &InviteParams{} - - jsonDecoder := json.NewDecoder(r.Body) - err := jsonDecoder.Decode(params) - if err != nil { - return badRequestError("Could not read Invite params: %v", err) - } - - if err := a.validateEmail(ctx, params.Email); err != nil { - return err - } - - aud := a.requestAud(ctx, r) - user, err := models.FindUserByEmailAndAudience(a.db, instanceID, params.Email, aud) - if err != nil && !models.IsNotFoundError(err) { - return internalServerError("Database error finding user").WithInternalError(err) - } - - err = a.db.Transaction(func(tx *storage.Connection) error { - if user != nil { - if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) - } - } else { - signupParams := SignupParams{ - Email: params.Email, - Data: params.Data, - Aud: aud, - Provider: "email", - } - user, err = a.signupNewUser(ctx, tx, &signupParams) - if err != nil { - return err - } - } - - if terr := models.NewAuditLogEntry(tx, instanceID, adminUser, models.UserInvitedAction, map[string]interface{}{ - "user_id": user.ID, - "user_email": user.Email, - }); terr != nil { - return terr - } - - mailer := a.Mailer(ctx) - referrer := a.getReferrer(r) - if err := sendInvite(tx, user, mailer, referrer); err != nil { - return internalServerError("Error inviting user").WithInternalError(err) - } - return nil - }) - if err != nil { - return err - } - - return sendJSON(w, http.StatusOK, user) -} diff --git a/api/log.go b/api/log.go deleted file mode 100644 index c51af9061..000000000 --- a/api/log.go +++ /dev/null @@ -1,81 +0,0 @@ -package api - -import ( - "fmt" - "net/http" - "time" - - chimiddleware "github.com/go-chi/chi/middleware" - "github.com/sirupsen/logrus" -) - -func newStructuredLogger(logger *logrus.Logger) func(next http.Handler) http.Handler { - return chimiddleware.RequestLogger(&structuredLogger{logger}) -} - -type structuredLogger struct { - Logger *logrus.Logger -} - -func (l *structuredLogger) NewLogEntry(r *http.Request) chimiddleware.LogEntry { - entry := &structuredLoggerEntry{Logger: logrus.NewEntry(l.Logger)} - logFields := logrus.Fields{ - "component": "api", - "method": r.Method, - "path": r.URL.Path, - "remote_addr": r.RemoteAddr, - "referer": r.Referer(), - } - - if reqID := getRequestID(r.Context()); reqID != "" { - logFields["request_id"] = reqID - } - - entry.Logger = entry.Logger.WithFields(logFields) - entry.Logger.Infoln("request started") - return entry -} - -type structuredLoggerEntry struct { - Logger logrus.FieldLogger -} - -func (l *structuredLoggerEntry) Write(status, bytes int, elapsed time.Duration) { - l.Logger = l.Logger.WithFields(logrus.Fields{ - "status": status, - "duration": elapsed.Nanoseconds(), - }) - - l.Logger.Info("request completed") -} - -func (l *structuredLoggerEntry) Panic(v interface{}, stack []byte) { - l.Logger.WithFields(logrus.Fields{ - "stack": string(stack), - "panic": fmt.Sprintf("%+v", v), - }).Panic("unhandled request panic") -} - -func getLogEntry(r *http.Request) logrus.FieldLogger { - entry, _ := chimiddleware.GetLogEntry(r).(*structuredLoggerEntry) - if entry == nil { - return logrus.NewEntry(logrus.StandardLogger()) - } - return entry.Logger -} - -func logEntrySetField(r *http.Request, key string, value interface{}) logrus.FieldLogger { - if entry, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*structuredLoggerEntry); ok { - entry.Logger = entry.Logger.WithField(key, value) - return entry.Logger - } - return nil -} - -func logEntrySetFields(r *http.Request, fields logrus.Fields) logrus.FieldLogger { - if entry, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*structuredLoggerEntry); ok { - entry.Logger = entry.Logger.WithFields(fields) - return entry.Logger - } - return nil -} diff --git a/api/logout.go b/api/logout.go deleted file mode 100644 index 77b4f2fa5..000000000 --- a/api/logout.go +++ /dev/null @@ -1,35 +0,0 @@ -package api - -import ( - "net/http" - - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" -) - -// Logout is the endpoint for logging out a user and thereby revoking any refresh tokens -func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - instanceID := getInstanceID(ctx) - config := getConfig(ctx) - - a.clearCookieTokens(config, w) - - u, err := getUserFromClaims(ctx, a.db) - if err != nil { - return unauthorizedError("Invalid user").WithInternalError(err) - } - - err = a.db.Transaction(func(tx *storage.Connection) error { - if terr := models.NewAuditLogEntry(tx, instanceID, u, models.LogoutAction, nil); terr != nil { - return terr - } - return models.Logout(tx, instanceID, u.ID) - }) - if err != nil { - return internalServerError("Error logging out user").WithInternalError(err) - } - - w.WriteHeader(http.StatusNoContent) - return nil -} diff --git a/api/magic_link.go b/api/magic_link.go deleted file mode 100644 index c4d6b58e0..000000000 --- a/api/magic_link.go +++ /dev/null @@ -1,111 +0,0 @@ -package api - -import ( - "encoding/json" - "errors" - "io/ioutil" - "net/http" - "strings" - - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/sethvargo/go-password/password" -) - -// MagicLinkParams holds the parameters for a magic link request -type MagicLinkParams struct { - Email string `json:"email"` -} - -// MagicLink sends a recovery email -func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - - if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") - } - - instanceID := getInstanceID(ctx) - params := &MagicLinkParams{} - jsonDecoder := json.NewDecoder(r.Body) - err := jsonDecoder.Decode(params) - if err != nil { - return badRequestError("Could not read verification params: %v", err) - } - - if params.Email == "" { - return unprocessableEntityError("Password recovery requires an email") - } - if err := a.validateEmail(ctx, params.Email); err != nil { - return err - } - - aud := a.requestAud(ctx, r) - user, err := models.FindUserByEmailAndAudience(a.db, instanceID, params.Email, aud) - if err != nil { - if models.IsNotFoundError(err) { - // User doesn't exist, sign them up with temporary password - password, err := password.Generate(64, 10, 0, false, true) - if err != nil { - internalServerError("error creating user").WithInternalError(err) - } - newBodyContent := `{"email":"` + params.Email + `","password":"` + password + `"}` - r.Body = ioutil.NopCloser(strings.NewReader(newBodyContent)) - r.ContentLength = int64(len(newBodyContent)) - - fakeResponse := &responseStub{} - if config.Mailer.Autoconfirm { - // signups are autoconfirmed, send magic link after signup - if err := a.Signup(fakeResponse, r); err != nil { - return err - } - newBodyContent := `{"email":"` + params.Email + `"}` - r.Body = ioutil.NopCloser(strings.NewReader(newBodyContent)) - r.ContentLength = int64(len(newBodyContent)) - return a.MagicLink(w, r) - } - // otherwise confirmation email already contains 'magic link' - if err := a.Signup(fakeResponse, r); err != nil { - return err - } - - return sendJSON(w, http.StatusOK, make(map[string]string)) - } - return internalServerError("Database error finding user").WithInternalError(err) - } - - err = a.db.Transaction(func(tx *storage.Connection) error { - if terr := models.NewAuditLogEntry(tx, instanceID, user, models.UserRecoveryRequestedAction, nil); terr != nil { - return terr - } - - mailer := a.Mailer(ctx) - referrer := a.getReferrer(r) - return a.sendMagicLink(tx, user, mailer, config.SMTP.MaxFrequency, referrer) - }) - if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") - } - return internalServerError("Error sending magic link").WithInternalError(err) - } - - return sendJSON(w, http.StatusOK, make(map[string]string)) -} - -// responseStub only implement http responsewriter for ignoring -// incoming data from methods where it passed -type responseStub struct { -} - -func (rw *responseStub) Header() http.Header { - return http.Header{} -} - -func (rw *responseStub) Write(data []byte) (int, error) { - return 1, nil -} - -func (rw *responseStub) WriteHeader(statusCode int) { -} diff --git a/api/mail.go b/api/mail.go deleted file mode 100644 index edf8b0f67..000000000 --- a/api/mail.go +++ /dev/null @@ -1,282 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/netlify/gotrue/crypto" - "github.com/netlify/gotrue/mailer" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" - "github.com/sethvargo/go-password/password" -) - -var ( - MaxFrequencyLimitError error = errors.New("Frequency limit reached") - configFile = "" -) - -type GenerateLinkParams struct { - Type string `json:"type"` - Email string `json:"email"` - Password string `json:"password"` - Data map[string]interface{} `json:"data"` - RedirectTo string `json:"redirect_to"` -} - -func (a *API) GenerateLink(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - mailer := a.Mailer(ctx) - instanceID := getInstanceID(ctx) - adminUser := getAdminUser(ctx) - - params := &GenerateLinkParams{} - jsonDecoder := json.NewDecoder(r.Body) - - if err := jsonDecoder.Decode(params); err != nil { - return badRequestError("Could not read body: %v", err) - } - - if err := a.validateEmail(ctx, params.Email); err != nil { - return err - } - - aud := a.requestAud(ctx, r) - user, err := models.FindUserByEmailAndAudience(a.db, instanceID, params.Email, aud) - if err != nil { - if models.IsNotFoundError(err) { - if params.Type == "magiclink" { - params.Type = "signup" - params.Password, err = password.Generate(64, 10, 0, false, true) - if err != nil { - return internalServerError("error creating user").WithInternalError(err) - } - } else if params.Type == "recovery" { - return notFoundError(err.Error()) - } - } else { - return internalServerError("Database error finding user").WithInternalError(err) - } - } - - var url string - referrer := a.getRedirectURLOrReferrer(r, params.RedirectTo) - now := time.Now() - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - switch params.Type { - case "magiclink", "recovery": - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserRecoveryRequestedAction, nil); terr != nil { - return terr - } - user.RecoveryToken = crypto.SecureToken() - user.RecoverySentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") - case "invite": - if user != nil { - if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) - } - } else { - signupParams := &SignupParams{ - Email: params.Email, - Data: params.Data, - Provider: "email", - Aud: aud, - } - user, terr = a.signupNewUser(ctx, tx, signupParams) - if terr != nil { - return terr - } - } - if terr = models.NewAuditLogEntry(tx, instanceID, adminUser, models.UserInvitedAction, map[string]interface{}{ - "user_id": user.ID, - "user_email": user.Email, - }); terr != nil { - return terr - } - user.ConfirmationToken = crypto.SecureToken() - user.ConfirmationSentAt = &now - user.InvitedAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") - case "signup": - if user != nil { - if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) - } - if err := user.UpdateUserMetaData(tx, params.Data); err != nil { - return internalServerError("Database error updating user").WithInternalError(err) - } - } else { - if params.Password == "" { - return unprocessableEntityError("Signup requires a valid password") - } - if len(params.Password) < config.PasswordMinLength { - return unprocessableEntityError(fmt.Sprintf("Password should be at least %d characters", config.PasswordMinLength)) - } - signupParams := &SignupParams{ - Email: params.Email, - Password: params.Password, - Data: params.Data, - Provider: "email", - Aud: aud, - } - user, terr = a.signupNewUser(ctx, tx, signupParams) - if terr != nil { - return terr - } - } - user.ConfirmationToken = crypto.SecureToken() - user.ConfirmationSentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") - default: - return badRequestError("Invalid email action link type requested: %v", params.Type) - } - - if terr != nil { - return terr - } - - url, terr = mailer.GetEmailActionLink(user, params.Type, referrer) - if terr != nil { - return terr - } - return nil - }) - - if err != nil { - return err - } - - resp := make(map[string]interface{}) - u, err := json.Marshal(user) - if err != nil { - return internalServerError("User serialization error").WithInternalError(err) - } - if err = json.Unmarshal(u, &resp); err != nil { - return internalServerError("User serialization error").WithInternalError(err) - } - resp["action_link"] = url - - return sendJSON(w, http.StatusOK, resp) -} - -func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string) error { - if u.ConfirmationSentAt != nil && !u.ConfirmationSentAt.Add(maxFrequency).Before(time.Now()) { - return MaxFrequencyLimitError - } - - oldToken := u.ConfirmationToken - u.ConfirmationToken = crypto.SecureToken() - now := time.Now() - if err := mailer.ConfirmationMail(u, referrerURL); err != nil { - u.ConfirmationToken = oldToken - return errors.Wrap(err, "Error sending confirmation email") - } - u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") -} - -func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string) error { - oldToken := u.ConfirmationToken - u.ConfirmationToken = crypto.SecureToken() - now := time.Now() - if err := mailer.InviteMail(u, referrerURL); err != nil { - u.ConfirmationToken = oldToken - return errors.Wrap(err, "Error sending invite email") - } - u.InvitedAt = &now - u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") -} - -func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string) error { - if u.RecoverySentAt != nil && !u.RecoverySentAt.Add(maxFrequency).Before(time.Now()) { - return MaxFrequencyLimitError - } - - oldToken := u.RecoveryToken - u.RecoveryToken = crypto.SecureToken() - now := time.Now() - if err := mailer.RecoveryMail(u, referrerURL); err != nil { - u.RecoveryToken = oldToken - return errors.Wrap(err, "Error sending recovery email") - } - u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") -} - -func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string) error { - // since Magic Link is just a recovery with a different template and behaviour - // around new users we will reuse the recovery db timer to prevent potential abuse - if u.RecoverySentAt != nil && !u.RecoverySentAt.Add(maxFrequency).Before(time.Now()) { - return MaxFrequencyLimitError - } - - oldToken := u.RecoveryToken - u.RecoveryToken = crypto.SecureToken() - now := time.Now() - if err := mailer.MagicLinkMail(u, referrerURL); err != nil { - u.RecoveryToken = oldToken - return errors.Wrap(err, "Error sending magic link email") - } - u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") -} - -// sendSecureEmailChange sends out an email change token each to the old and new emails. -func (a *API) sendSecureEmailChange(tx *storage.Connection, u *models.User, mailer mailer.Mailer, email string, referrerURL string) error { - u.EmailChangeTokenCurrent, u.EmailChangeTokenNew = crypto.SecureToken(), crypto.SecureToken() - u.EmailChange = email - u.EmailChangeConfirmStatus = zeroConfirmation - now := time.Now() - if err := mailer.EmailChangeMail(u, referrerURL); err != nil { - return err - } - - u.EmailChangeSentAt = &now - return errors.Wrap(tx.UpdateOnly( - u, - "email_change_token_current", - "email_change_token_new", - "email_change", - "email_change_sent_at", - "email_change_confirm_status", - ), "Database error updating user for email change") -} - -// sendEmailChange sends out an email change token to the new email. -func (a *API) sendEmailChange(tx *storage.Connection, u *models.User, mailer mailer.Mailer, email string, referrerURL string) error { - u.EmailChangeTokenNew = crypto.SecureToken() - u.EmailChange = email - u.EmailChangeConfirmStatus = zeroConfirmation - now := time.Now() - if err := mailer.EmailChangeMail(u, referrerURL); err != nil { - return err - } - - u.EmailChangeSentAt = &now - return errors.Wrap(tx.UpdateOnly( - u, - "email_change_token_new", - "email_change", - "email_change_sent_at", - "email_change_confirm_status", - ), "Database error updating user for email change") -} - -func (a *API) validateEmail(ctx context.Context, email string) error { - if email == "" { - return unprocessableEntityError("An email address is required") - } - mailer := a.Mailer(ctx) - if err := mailer.ValidateEmail(email); err != nil { - return unprocessableEntityError("Unable to validate email address: " + err.Error()) - } - return nil -} diff --git a/api/middleware.go b/api/middleware.go deleted file mode 100644 index 54e6fdc07..000000000 --- a/api/middleware.go +++ /dev/null @@ -1,251 +0,0 @@ -package api - -import ( - "bytes" - "context" - "encoding/json" - "io" - "io/ioutil" - "net/http" - "strings" - "time" - - "github.com/netlify/gotrue/security" - "github.com/sirupsen/logrus" - - "github.com/didip/tollbooth/v5" - "github.com/didip/tollbooth/v5/limiter" - "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/models" -) - -const ( - jwsSignatureHeaderName = "x-nf-sign" -) - -type FunctionHooks map[string][]string - -type NetlifyMicroserviceClaims struct { - jwt.StandardClaims - SiteURL string `json:"site_url"` - InstanceID string `json:"id"` - NetlifyID string `json:"netlify_id"` - FunctionHooks FunctionHooks `json:"function_hooks"` -} - -func (f *FunctionHooks) UnmarshalJSON(b []byte) error { - var raw map[string][]string - err := json.Unmarshal(b, &raw) - if err == nil { - *f = FunctionHooks(raw) - return nil - } - // If unmarshaling into map[string][]string fails, try legacy format. - var legacy map[string]string - err = json.Unmarshal(b, &legacy) - if err != nil { - return err - } - if *f == nil { - *f = make(FunctionHooks) - } - for event, hook := range legacy { - (*f)[event] = []string{hook} - } - return nil -} - -func addGetBody(w http.ResponseWriter, req *http.Request) (context.Context, error) { - if req.Method == http.MethodGet { - return req.Context(), nil - } - - if req.Body == nil || req.Body == http.NoBody { - return nil, badRequestError("request must provide a body") - } - - buf, err := ioutil.ReadAll(req.Body) - if err != nil { - return nil, internalServerError("Error reading body").WithInternalError(err) - } - req.GetBody = func() (io.ReadCloser, error) { - return ioutil.NopCloser(bytes.NewReader(buf)), nil - } - req.Body, _ = req.GetBody() - return req.Context(), nil -} - -func (a *API) loadJWSSignatureHeader(w http.ResponseWriter, r *http.Request) (context.Context, error) { - ctx := r.Context() - signature := r.Header.Get(jwsSignatureHeaderName) - if signature == "" { - return nil, badRequestError("Operator microservice headers missing") - } - return withSignature(ctx, signature), nil -} - -func (a *API) loadInstanceConfig(w http.ResponseWriter, r *http.Request) (context.Context, error) { - ctx := r.Context() - config := a.getConfig(ctx) - - signature := getSignature(ctx) - if signature == "" { - return nil, badRequestError("Operator signature missing") - } - - claims := NetlifyMicroserviceClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} - _, err := p.ParseWithClaims(signature, &claims, func(token *jwt.Token) (interface{}, error) { - return []byte(config.JWT.Secret), nil - }) - if err != nil { - return nil, badRequestError("Operator microservice signature is invalid: %v", err) - } - - if claims.InstanceID == "" { - return nil, badRequestError("Instance ID is missing") - } - instanceID, err := uuid.FromString(claims.InstanceID) - if err != nil { - return nil, badRequestError("Instance ID is not a valid UUID") - } - - logEntrySetField(r, "instance_id", instanceID) - logEntrySetField(r, "netlify_id", claims.NetlifyID) - instance, err := models.GetInstance(a.db, instanceID) - if err != nil { - if models.IsNotFoundError(err) { - return nil, notFoundError("Unable to locate site configuration") - } - return nil, internalServerError("Database error loading instance").WithInternalError(err) - } - - config, err = instance.Config() - if err != nil { - return nil, internalServerError("Error loading environment config").WithInternalError(err) - } - - if claims.SiteURL != "" { - config.SiteURL = claims.SiteURL - } - logEntrySetField(r, "site_url", config.SiteURL) - - ctx = withNetlifyID(ctx, claims.NetlifyID) - ctx = withFunctionHooks(ctx, claims.FunctionHooks) - - ctx, err = WithInstanceConfig(ctx, config, instanceID) - if err != nil { - return nil, internalServerError("Error loading instance config").WithInternalError(err) - } - - return ctx, nil -} - -func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { - return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { - c := req.Context() - if limitHeader := a.config.RateLimitHeader; limitHeader != "" { - key := req.Header.Get(a.config.RateLimitHeader) - err := tollbooth.LimitByKeys(lmt, []string{key}) - if err != nil { - return c, httpError(http.StatusTooManyRequests, "Rate limit exceeded") - } - } - return c, nil - } -} - -func (a *API) limitEmailSentHandler() middlewareHandler { - // limit per hour - freq := a.config.RateLimitEmailSent / (60 * 60) - lmt := tollbooth.NewLimiter(freq, &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(int(a.config.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"}) - return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { - c := req.Context() - config := a.getConfig(c) - if config.External.Email.Enabled && !config.Mailer.Autoconfirm { - if req.Method == "PUT" || req.Method == "POST" { - res := make(map[string]interface{}) - bodyBytes, err := ioutil.ReadAll(req.Body) - if err != nil { - return c, internalServerError("Error invalid request body").WithInternalError(err) - } - req.Body.Close() - req.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) - - jsonDecoder := json.NewDecoder(bytes.NewBuffer(bodyBytes)) - if err := jsonDecoder.Decode(&res); err != nil { - return c, badRequestError("Error invalid request body").WithInternalError(err) - } - - if _, ok := res["email"]; !ok { - // email not in POST body - return c, nil - } - - if err := tollbooth.LimitByKeys(lmt, []string{"email_functions"}); err != nil { - return c, httpError(http.StatusTooManyRequests, "Rate limit exceeded") - } - } - } - return c, nil - } -} - -func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) { - ctx := req.Context() - t, err := a.extractBearerToken(w, req) - if err != nil || t == "" { - return nil, err - } - - ctx, err = a.parseJWTClaims(t, req, w) - if err != nil { - return nil, err - } - - return a.requireAdmin(ctx, w, req) -} - -func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (context.Context, error) { - ctx := req.Context() - config := a.getConfig(ctx) - - if !config.External.Email.Enabled { - return nil, badRequestError("Email logins are disabled") - } - - return ctx, nil -} - -func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.Context, error) { - ctx := req.Context() - config := a.getConfig(ctx) - if !config.Security.Captcha.Enabled { - return ctx, nil - } - if config.Security.Captcha.Provider != "hcaptcha" { - logrus.WithField("provider", config.Security.Captcha.Provider).Warn("Unsupported captcha provider") - return nil, internalServerError("server misconfigured") - } - secret := strings.TrimSpace(config.Security.Captcha.Secret) - if secret == "" { - return nil, internalServerError("server misconfigured") - } - verificationResult, err := security.VerifyRequest(req, secret) - if err != nil { - logrus.WithField("err", err).Infof("failed to validate result") - return nil, internalServerError("request validation failure") - } - if verificationResult == security.VerificationProcessFailure { - return nil, internalServerError("request validation failure") - } else if verificationResult == security.UserRequestFailed { - return nil, badRequestError("request disallowed") - } - if verificationResult == security.SuccessfullyVerified { - return ctx, nil - } - return nil, internalServerError("") -} diff --git a/api/middleware_test.go b/api/middleware_test.go deleted file mode 100644 index 1c99037c3..000000000 --- a/api/middleware_test.go +++ /dev/null @@ -1,172 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -const ( - HCaptchaSecret string = "0x0000000000000000000000000000000000000000" - HCaptchaResponse string = "10000000-aaaa-bbbb-cccc-000000000001" -) - -type MiddlewareTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - - instanceID uuid.UUID -} - -func TestHCaptcha(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &MiddlewareTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *MiddlewareTestSuite) TestVerifyCaptchaValid() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - "password": "secret", - "gotrue_meta_security": map[string]interface{}{ - "hcaptcha_token": HCaptchaResponse, - }, - })) - - ts.Config.Security.Captcha.Enabled = true - ts.Config.Security.Captcha.Provider = "hcaptcha" - ts.Config.Security.Captcha.Secret = HCaptchaSecret - - req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) - req.Header.Set("Content-Type", "application/json") - beforeCtx, err := WithInstanceConfig(req.Context(), ts.Config, ts.instanceID) - require.NoError(ts.T(), err) - - req = req.WithContext(beforeCtx) - - w := httptest.NewRecorder() - - afterCtx, err := ts.API.verifyCaptcha(w, req) - require.NoError(ts.T(), err) - - body, err := ioutil.ReadAll(req.Body) - - // re-initialize buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - "password": "secret", - "gotrue_meta_security": map[string]interface{}{ - "hcaptcha_token": HCaptchaResponse, - }, - })) - - // check if body is the same - require.Equal(ts.T(), body, buffer.Bytes()) - require.Equal(ts.T(), afterCtx, beforeCtx) -} - -func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { - cases := []struct { - desc string - captchaConf *conf.CaptchaConfiguration - expectedCode int - expectedMsg string - }{ - { - "Unsupported provider", - &conf.CaptchaConfiguration{ - Enabled: true, - Provider: "test", - }, - http.StatusInternalServerError, - "server misconfigured", - }, - { - "Missing secret", - &conf.CaptchaConfiguration{ - Enabled: true, - Provider: "hcaptcha", - Secret: "", - }, - http.StatusInternalServerError, - "server misconfigured", - }, - { - "Captcha validation failed", - &conf.CaptchaConfiguration{ - Enabled: true, - Provider: "hcaptcha", - Secret: "test", - }, - http.StatusInternalServerError, - "request validation failure", - }, - } - for _, c := range cases { - ts.Run(c.desc, func() { - ts.Config.Security.Captcha = *c.captchaConf - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - "password": "secret", - "gotrue_meta_security": map[string]interface{}{ - "hcaptcha_token": HCaptchaResponse, - }, - })) - req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) - req.Header.Set("Content-Type", "application/json") - ctx, err := WithInstanceConfig(req.Context(), ts.Config, ts.instanceID) - require.NoError(ts.T(), err) - - req = req.WithContext(ctx) - - w := httptest.NewRecorder() - - _, err = ts.API.verifyCaptcha(w, req) - require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).Code) - require.Equal(ts.T(), c.expectedMsg, err.(*HTTPError).Message) - }) - } -} - -func TestFunctionHooksUnmarshalJSON(t *testing.T) { - tests := []struct { - in string - ok bool - }{ - {`{ "signup" : "identity-signup" }`, true}, - {`{ "signup" : ["identity-signup"] }`, true}, - {`{ "signup" : {"foo" : "bar"} }`, false}, - } - for _, tt := range tests { - t.Run(tt.in, func(t *testing.T) { - var f FunctionHooks - err := json.Unmarshal([]byte(tt.in), &f) - if tt.ok { - assert.NoError(t, err) - assert.Equal(t, FunctionHooks{"signup": {"identity-signup"}}, f) - } else { - assert.Error(t, err) - } - }) - } -} diff --git a/api/otp.go b/api/otp.go deleted file mode 100644 index 6b8a86b81..000000000 --- a/api/otp.go +++ /dev/null @@ -1,137 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" - "strings" - - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/sethvargo/go-password/password" -) - -// OtpParams contains the request body params for the otp endpoint -type OtpParams struct { - Email string `json:"email"` - Phone string `json:"phone"` - CreateUser bool `json:"create_user"` -} - -// SmsParams contains the request body params for sms otp -type SmsParams struct { - Phone string `json:"phone"` -} - -// Otp returns the MagicLink or SmsOtp handler based on the request body params -func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { - params := &OtpParams{ - CreateUser: true, - } - body, err := ioutil.ReadAll(r.Body) - jsonDecoder := json.NewDecoder(bytes.NewReader(body)) - if err = jsonDecoder.Decode(params); err != nil { - return badRequestError("Could not read verification params: %v", err) - } - if params.Email != "" && params.Phone != "" { - return badRequestError("Only an email address or phone number should be provided") - } - - r.Body = ioutil.NopCloser(strings.NewReader(string(body))) - - if !a.shouldCreateUser(r, params) { - return badRequestError("Signups not allowed for otp") - } - - if params.Email != "" { - return a.MagicLink(w, r) - } else if params.Phone != "" { - return a.SmsOtp(w, r) - } - - return otpError("unsupported_otp_type", "") -} - -// SmsOtp sends the user an otp via sms -func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - - if !config.External.Phone.Enabled { - return badRequestError("Unsupported phone provider") - } - - instanceID := getInstanceID(ctx) - params := &SmsParams{} - jsonDecoder := json.NewDecoder(r.Body) - if err := jsonDecoder.Decode(params); err != nil { - return badRequestError("Could not read sms otp params: %v", err) - } - - params.Phone = a.formatPhoneNumber(params.Phone) - - if isValid := a.validateE164Format(params.Phone); !isValid { - return badRequestError("Invalid format: Phone number should follow the E.164 format") - } - - aud := a.requestAud(ctx, r) - - user, uerr := models.FindUserByPhoneAndAudience(a.db, instanceID, params.Phone, aud) - if uerr != nil { - // if user does not exists, sign up the user - if models.IsNotFoundError(uerr) { - password, err := password.Generate(64, 10, 0, false, true) - if err != nil { - internalServerError("error creating user").WithInternalError(err) - } - newBodyContent := `{"phone":"` + params.Phone + `","password":"` + password + `"}` - r.Body = ioutil.NopCloser(strings.NewReader(newBodyContent)) - r.ContentLength = int64(len(newBodyContent)) - - fakeResponse := &responseStub{} - - if err := a.Signup(fakeResponse, r); err != nil { - return err - } - return sendJSON(w, http.StatusOK, make(map[string]string)) - } - return internalServerError("Database error finding user").WithInternalError(uerr) - } - - err := a.db.Transaction(func(tx *storage.Connection) error { - if err := models.NewAuditLogEntry(tx, instanceID, user, models.UserRecoveryRequestedAction, nil); err != nil { - return err - } - - if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone); err != nil { - return badRequestError("Error sending sms otp: %v", err) - } - return nil - }) - - if err != nil { - return err - } - - return sendJSON(w, http.StatusOK, make(map[string]string)) -} - -func (a *API) shouldCreateUser(r *http.Request, params *OtpParams) bool { - if !params.CreateUser { - ctx := r.Context() - instanceID := getInstanceID(ctx) - aud := a.requestAud(ctx, r) - var err error - if params.Email != "" { - _, err = models.FindUserByEmailAndAudience(a.db, instanceID, params.Email, aud) - } else if params.Phone != "" { - _, err = models.FindUserByPhoneAndAudience(a.db, instanceID, params.Phone, aud) - } - - if err != nil && models.IsNotFoundError(err) { - return false - } - } - return true -} diff --git a/api/otp_test.go b/api/otp_test.go deleted file mode 100644 index 97851b8d0..000000000 --- a/api/otp_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type OtpTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - - instanceID uuid.UUID -} - -func TestOtp(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &OtpTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *OtpTestSuite) SetupTest() { - models.TruncateAll(ts.API.db) -} - -func (ts *OtpTestSuite) TestOtp() { - cases := []struct { - desc string - params OtpParams - expected struct { - code int - response map[string]interface{} - } - }{ - { - "Test Success Magiclink Otp", - OtpParams{ - Email: "test@example.com", - CreateUser: true, - }, - struct { - code int - response map[string]interface{} - }{ - http.StatusOK, - make(map[string]interface{}), - }, - }, - { - "Test Failure Pass Both Email & Phone", - OtpParams{ - Email: "test@example.com", - Phone: "123456789", - CreateUser: true, - }, - struct { - code int - response map[string]interface{} - }{ - http.StatusBadRequest, - map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Only an email address or phone number should be provided", - }, - }, - }, - } - - for _, c := range cases { - ts.Run(c.desc, func() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) - - req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - - require.Equal(ts.T(), c.expected.code, w.Code) - - data := make(map[string]interface{}) - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - // response should be empty - assert.Equal(ts.T(), data, c.expected.response) - }) - } -} - -func (ts *OtpTestSuite) TestNoSignupsForOtp() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "newuser@example.com", - "create_user": false, - })) - - req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - - require.Equal(ts.T(), http.StatusBadRequest, w.Code) - - data := make(map[string]interface{}) - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - // response should be empty - assert.Equal(ts.T(), data, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Signups not allowed for otp", - }) -} diff --git a/api/phone.go b/api/phone.go deleted file mode 100644 index c36397331..000000000 --- a/api/phone.go +++ /dev/null @@ -1,67 +0,0 @@ -package api - -import ( - "context" - "fmt" - "regexp" - "strings" - "time" - - "github.com/netlify/gotrue/api/sms_provider" - "github.com/netlify/gotrue/crypto" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" -) - -const e164Format = `^[1-9]\d{1,14}$` -const defaultSmsMessage = "Your code is %v" - -// validateE165Format checks if phone number follows the E.164 format -func (a *API) validateE164Format(phone string) bool { - // match should never fail as long as regexp is valid - matched, _ := regexp.Match(e164Format, []byte(phone)) - return matched -} - -// formatPhoneNumber removes "+" and whitespaces in a phone number -func (a *API) formatPhoneNumber(phone string) string { - return strings.ReplaceAll(strings.Trim(phone, "+"), " ", "") -} - -func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone string) error { - config := a.getConfig(ctx) - - if user.ConfirmationSentAt != nil && !user.ConfirmationSentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { - return MaxFrequencyLimitError - } - - oldToken := user.ConfirmationToken - otp, err := crypto.GenerateOtp(config.Sms.OtpLength) - if err != nil { - return internalServerError("error generating otp").WithInternalError(err) - } - user.ConfirmationToken = otp - - smsProvider, err := sms_provider.GetSmsProvider(*config) - if err != nil { - return err - } - - var message string - if config.Sms.Template == "" { - message = fmt.Sprintf(defaultSmsMessage, user.ConfirmationToken) - } else { - message = strings.Replace(config.Sms.Template, "{{ .Code }}", user.ConfirmationToken, -1) - } - - if serr := smsProvider.SendSms(phone, message); serr != nil { - user.ConfirmationToken = oldToken - return serr - } - - now := time.Now() - user.ConfirmationSentAt = &now - - return errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") -} diff --git a/api/provider/apple.go b/api/provider/apple.go deleted file mode 100644 index 148841af9..000000000 --- a/api/provider/apple.go +++ /dev/null @@ -1,192 +0,0 @@ -package provider - -import ( - "context" - "crypto/rsa" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - - "github.com/golang-jwt/jwt" - "github.com/lestrrat-go/jwx/jwk" - "github.com/netlify/gotrue/conf" - "golang.org/x/oauth2" -) - -const ( - defaultAppleAPIBase = "appleid.apple.com" - authEndpoint = "/auth/authorize" - tokenEndpoint = "/auth/token" - - scopeEmail = "email" - scopeName = "name" - - appleAudOrIss = "https://appleid.apple.com" - idTokenVerificationKeyEndpoint = "/auth/keys" -) - -// AppleProvider stores the custom config for apple provider -type AppleProvider struct { - *oauth2.Config - httpClient *http.Client - UserInfoURL string -} - -type appleName struct { - FirstName string `json:"firstName"` - LastName string `json:"lastName"` -} - -type appleUser struct { - Name appleName `json:"name"` - Email string `json:"email"` -} - -type idTokenClaims struct { - jwt.StandardClaims - AccessTokenHash string `json:"at_hash"` - AuthTime int `json:"auth_time"` - Email string `json:"email"` - IsPrivateEmail bool `json:"is_private_email,string"` - Sub string `json:"sub"` -} - -// NewAppleProvider creates a Apple account provider. -func NewAppleProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { - return nil, err - } - - authHost := chooseHost(ext.URL, defaultAppleAPIBase) - - return &AppleProvider{ - Config: &oauth2.Config{ - ClientID: ext.ClientID, - ClientSecret: ext.Secret, - Endpoint: oauth2.Endpoint{ - AuthURL: authHost + authEndpoint, - TokenURL: authHost + tokenEndpoint, - }, - Scopes: []string{ - scopeEmail, - scopeName, - }, - RedirectURL: ext.RedirectURI, - }, - UserInfoURL: authHost + idTokenVerificationKeyEndpoint, - }, nil -} - -// GetOAuthToken returns the apple provider access token -func (p AppleProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - opts := []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("client_id", p.ClientID), - oauth2.SetAuthURLParam("secret", p.ClientSecret), - } - return p.Exchange(oauth2.NoContext, code, opts...) -} - -func (p AppleProvider) AuthCodeURL(state string, args ...oauth2.AuthCodeOption) string { - opts := make([]oauth2.AuthCodeOption, 0, 1) - opts = append(opts, oauth2.SetAuthURLParam("response_mode", "form_post")) - authURL := p.Config.AuthCodeURL(state, opts...) - if authURL != "" { - if u, err := url.Parse(authURL); err != nil { - u.RawQuery = strings.ReplaceAll(u.RawQuery, "+", "%20") - authURL = u.String() - } - } - return authURL -} - -// GetUserData returns the user data fetched from the apple provider -func (p AppleProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { - var user *UserProvidedData - if tok.AccessToken == "" { - return &UserProvidedData{}, nil - } - if idToken := tok.Extra("id_token"); idToken != nil { - idToken, err := jwt.ParseWithClaims(idToken.(string), &idTokenClaims{}, func(t *jwt.Token) (interface{}, error) { - kid := t.Header["kid"].(string) - claims := t.Claims.(*idTokenClaims) - vErr := new(jwt.ValidationError) - if !claims.VerifyAudience(p.ClientID, true) { - vErr.Inner = fmt.Errorf("incorrect audience") - vErr.Errors |= jwt.ValidationErrorAudience - } - if !claims.VerifyIssuer(appleAudOrIss, true) { - vErr.Inner = fmt.Errorf("incorrect issuer") - vErr.Errors |= jwt.ValidationErrorIssuer - } - if vErr.Errors > 0 { - return nil, vErr - } - - // per OpenID Connect Core 1.0 §3.2.2.9, Access Token Validation - hash := sha256.Sum256([]byte(tok.AccessToken)) - halfHash := hash[0:(len(hash) / 2)] - encodedHalfHash := base64.RawURLEncoding.EncodeToString(halfHash) - if encodedHalfHash != claims.AccessTokenHash { - vErr.Inner = fmt.Errorf(`invalid identity token`) - vErr.Errors |= jwt.ValidationErrorClaimsInvalid - return nil, vErr - } - - // get the public key for verifying the identity token signature - set, err := jwk.FetchHTTP(p.UserInfoURL, jwk.WithHTTPClient(http.DefaultClient)) - if err != nil { - return nil, err - } - selectedKey := set.Keys[0] - for _, key := range set.Keys { - if key.KeyID() == kid { - selectedKey = key - break - } - } - pubKeyIface, _ := selectedKey.Materialize() - pubKey, ok := pubKeyIface.(*rsa.PublicKey) - if !ok { - return nil, fmt.Errorf(`expected RSA public key from %s`, p.UserInfoURL) - } - return pubKey, nil - }) - if err != nil { - return &UserProvidedData{}, err - } - user = &UserProvidedData{ - Emails: []Email{{ - Email: idToken.Claims.(*idTokenClaims).Email, - Verified: true, - Primary: true, - }}, - Metadata: &Claims{ - Issuer: p.UserInfoURL, - Subject: idToken.Claims.(*idTokenClaims).Sub, - Email: idToken.Claims.(*idTokenClaims).Email, - EmailVerified: true, - - // To be deprecated - ProviderId: idToken.Claims.(*idTokenClaims).Sub, - }, - } - } - return user, nil -} - -// ParseUser parses the apple user's info -func (p AppleProvider) ParseUser(data string, userData *UserProvidedData) error { - u := &appleUser{} - err := json.Unmarshal([]byte(data), u) - if err != nil { - return err - } - - userData.Metadata.Name = strings.TrimSpace(u.Name.FirstName + " " + u.Name.LastName) - userData.Metadata.FullName = strings.TrimSpace(u.Name.FirstName + " " + u.Name.LastName) - return nil -} diff --git a/api/provider/azure.go b/api/provider/azure.go deleted file mode 100644 index fe78cf926..000000000 --- a/api/provider/azure.go +++ /dev/null @@ -1,96 +0,0 @@ -package provider - -import ( - "context" - "errors" - "strings" - - "github.com/netlify/gotrue/conf" - "golang.org/x/oauth2" -) - -const ( - defaultAzureAuthBase = "login.microsoftonline.com" - defaultAzureAPIBase = "graph.microsoft.com" -) - -type azureProvider struct { - *oauth2.Config - APIPath string -} - -type azureUser struct { - Name string `json:"name"` - Email string `json:"email"` - Sub string `json:"sub"` -} - -type azureEmail struct { - Email string `json:"email"` - Primary bool `json:"is_primary"` - Verified bool `json:"is_confirmed"` -} - -// NewAzureProvider creates a Azure account provider. -func NewAzureProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { - return nil, err - } - - authHost := chooseHost(ext.URL, defaultAzureAuthBase) - apiPath := chooseHost(ext.URL, defaultAzureAPIBase) - - oauthScopes := []string{"openid"} - - if scopes != "" { - oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) - } - - return &azureProvider{ - Config: &oauth2.Config{ - ClientID: ext.ClientID, - ClientSecret: ext.Secret, - Endpoint: oauth2.Endpoint{ - AuthURL: authHost + "/common/oauth2/v2.0/authorize", - TokenURL: authHost + "/common/oauth2/v2.0/token", - }, - RedirectURL: ext.RedirectURI, - Scopes: oauthScopes, - }, - APIPath: apiPath, - }, nil -} - -func (g azureProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) -} - -func (g azureProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { - var u azureUser - if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/oidc/userinfo", &u); err != nil { - return nil, err - } - - if u.Email == "" { - return nil, errors.New("Unable to find email with Azure provider") - } - - return &UserProvidedData{ - Metadata: &Claims{ - Issuer: g.APIPath, - Subject: u.Sub, - Name: u.Name, - Email: u.Email, - EmailVerified: true, - - // To be deprecated - FullName: u.Name, - ProviderId: u.Sub, - }, - Emails: []Email{{ - Email: u.Email, - Verified: true, - Primary: true, - }}, - }, nil -} diff --git a/api/provider/google.go b/api/provider/google.go deleted file mode 100644 index e2e060ecb..000000000 --- a/api/provider/google.go +++ /dev/null @@ -1,102 +0,0 @@ -package provider - -import ( - "context" - "errors" - "strings" - - "github.com/netlify/gotrue/conf" - "golang.org/x/oauth2" -) - -const ( - defaultGoogleAuthBase = "accounts.google.com" - defaultGoogleAPIBase = "www.googleapis.com" -) - -type googleProvider struct { - *oauth2.Config - APIPath string -} - -type googleUser struct { - ID string `json:"id"` - Name string `json:"name"` - AvatarURL string `json:"picture"` - Email string `json:"email"` - EmailVerified bool `json:"verified_email"` -} - -// NewGoogleProvider creates a Google account provider. -func NewGoogleProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { - return nil, err - } - - authHost := chooseHost(ext.URL, defaultGoogleAuthBase) - apiPath := chooseHost(ext.URL, defaultGoogleAPIBase) + "/userinfo/v2/me" - - oauthScopes := []string{ - "email", - "profile", - } - - if scopes != "" { - oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) - } - - return &googleProvider{ - Config: &oauth2.Config{ - ClientID: ext.ClientID, - ClientSecret: ext.Secret, - Endpoint: oauth2.Endpoint{ - AuthURL: authHost + "/o/oauth2/auth", - TokenURL: authHost + "/o/oauth2/token", - }, - Scopes: oauthScopes, - RedirectURL: ext.RedirectURI, - }, - APIPath: apiPath, - }, nil -} - -func (g googleProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) -} - -func (g googleProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { - var u googleUser - if err := makeRequest(ctx, tok, g.Config, g.APIPath, &u); err != nil { - return nil, err - } - - data := &UserProvidedData{} - - if u.Email != "" { - data.Emails = append(data.Emails, Email{ - Email: u.Email, - Verified: u.EmailVerified, - Primary: true, - }) - } - - if len(data.Emails) <= 0 { - return nil, errors.New("Unable to find email with Google provider") - } - - data.Metadata = &Claims{ - Issuer: g.APIPath, - Subject: u.ID, - Name: u.Name, - Picture: u.AvatarURL, - Email: u.Email, - EmailVerified: u.EmailVerified, - - // To be deprecated - AvatarURL: u.AvatarURL, - FullName: u.Name, - ProviderId: u.ID, - } - - return data, nil -} diff --git a/api/provider/provider.go b/api/provider/provider.go deleted file mode 100644 index afb36c5ee..000000000 --- a/api/provider/provider.go +++ /dev/null @@ -1,114 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - - "golang.org/x/oauth2" -) - -type Claims struct { - // Reserved claims - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Aud string `json:"aud,omitempty"` - Iat float64 `json:"iat,omitempty"` - Exp float64 `json:"exp,omitempty"` - - // Default profile claims - Name string `json:"name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - GivenName string `json:"given_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - NickName string `json:"nickname,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - ZoneInfo string `json:"zoneinfo,omitempty"` - Locale string `json:"locale,omitempty"` - UpdatedAt string `json:"updated_at,omitempty"` - Email string `json:"email,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` - Phone string `json:"phone,omitempty"` - PhoneVerified bool `json:"phone_verified,omitempty"` - - // Custom profile claims that are provider specific - CustomClaims map[string]interface{} `json:"custom_claims,omitempty"` - - // TODO: Deprecate in next major release - FullName string `json:"full_name,omitempty"` - AvatarURL string `json:"avatar_url,omitempty"` - Slug string `json:"slug,omitempty"` - ProviderId string `json:"provider_id,omitempty"` - UserNameKey string `json:"user_name,omitempty"` -} - -// ToMap converts the Claims struct to a map[string]interface{} -func (c *Claims) ToMap() (map[string]interface{}, error) { - m := make(map[string]interface{}) - cBytes, err := json.Marshal(c) - if err != nil { - return nil, err - } - err = json.Unmarshal(cBytes, &m) - if err != nil { - return nil, err - } - return m, nil -} - -// Email is a struct that provides information on whether an email is verified or is the primary email address -type Email struct { - Email string - Verified bool - Primary bool -} - -// UserProvidedData is a struct that contains the user's data returned from the oauth provider -type UserProvidedData struct { - Emails []Email - Metadata *Claims -} - -// Provider is an interface for interacting with external account providers -type Provider interface { - AuthCodeURL(string, ...oauth2.AuthCodeOption) string -} - -// OAuthProvider specifies additional methods needed for providers using OAuth -type OAuthProvider interface { - AuthCodeURL(string, ...oauth2.AuthCodeOption) string - GetUserData(context.Context, *oauth2.Token) (*UserProvidedData, error) - GetOAuthToken(string) (*oauth2.Token, error) -} - -func chooseHost(base, defaultHost string) string { - if base == "" { - return "https://" + defaultHost - } - - baseLen := len(base) - if base[baseLen-1] == '/' { - return base[:baseLen-1] - } - - return base -} - -func makeRequest(ctx context.Context, tok *oauth2.Token, g *oauth2.Config, url string, dst interface{}) error { - client := g.Client(ctx, tok) - res, err := client.Get(url) - if err != nil { - return err - } - defer res.Body.Close() - - if err := json.NewDecoder(res.Body).Decode(dst); err != nil { - return err - } - - return nil -} diff --git a/api/provider/saml.go b/api/provider/saml.go deleted file mode 100644 index 9b5af7b6e..000000000 --- a/api/provider/saml.go +++ /dev/null @@ -1,263 +0,0 @@ -package provider - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/pem" - "encoding/xml" - "errors" - "fmt" - "io/ioutil" - "math/big" - "net/http" - "net/url" - "strings" - "time" - - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - - "github.com/netlify/gotrue/conf" - saml2 "github.com/russellhaering/gosaml2" - "github.com/russellhaering/gosaml2/types" - dsig "github.com/russellhaering/goxmldsig" - "github.com/gofrs/uuid" - "golang.org/x/oauth2" -) - -type SamlProvider struct { - ServiceProvider *saml2.SAMLServiceProvider -} - -type ConfigX509KeyStore struct { - InstanceID uuid.UUID - DB *storage.Connection - Conf conf.SamlProviderConfiguration -} - -func getMetadata(url string) (*types.EntityDescriptor, error) { - res, err := http.Get(url) - if err != nil { - return nil, err - } - - if res.StatusCode >= 300 { - return nil, fmt.Errorf("Request failed with status %s", res.Status) - } - - rawMetadata, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, err - } - - metadata := &types.EntityDescriptor{} - err = xml.Unmarshal(rawMetadata, metadata) - if err != nil { - return nil, err - } - - // TODO: cache in memory - - return metadata, nil -} - -// NewSamlProvider creates a Saml account provider. -func NewSamlProvider(ext conf.SamlProviderConfiguration, db *storage.Connection, instanceId uuid.UUID) (*SamlProvider, error) { - if !ext.Enabled { - return nil, errors.New("SAML Provider is not enabled") - } - - if _, err := url.Parse(ext.MetadataURL); err != nil { - return nil, fmt.Errorf("Metadata URL is invalid: %+v", err) - } - - meta, err := getMetadata(ext.MetadataURL) - if err != nil { - return nil, fmt.Errorf("Fetching metadata failed: %+v", err) - } - - baseURI, err := url.Parse(strings.Trim(ext.APIBase, "/")) - if err != nil || ext.APIBase == "" { - return nil, fmt.Errorf("Invalid API base URI: %s", ext.APIBase) - } - - var ssoService types.SingleSignOnService - foundService := false - for _, service := range meta.IDPSSODescriptor.SingleSignOnServices { - if service.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" { - ssoService = service - foundService = true - break - } - } - if !foundService { - return nil, errors.New("No valid SSO service found in IDP metadata") - } - - certStore := dsig.MemoryX509CertificateStore{ - Roots: []*x509.Certificate{}, - } - - for _, kd := range meta.IDPSSODescriptor.KeyDescriptors { - for _, xcert := range kd.KeyInfo.X509Data.X509Certificates { - if xcert.Data == "" { - continue - } - certData, err := base64.StdEncoding.DecodeString(xcert.Data) - if err != nil { - continue - } - - idpCert, err := x509.ParseCertificate(certData) - if err != nil { - continue - } - - certStore.Roots = append(certStore.Roots, idpCert) - } - } - - keyStore := &ConfigX509KeyStore{ - InstanceID: instanceId, - DB: db, - Conf: ext, - } - - sp := &saml2.SAMLServiceProvider{ - IdentityProviderSSOURL: ssoService.Location, - IdentityProviderIssuer: meta.EntityID, - AssertionConsumerServiceURL: baseURI.String() + "/saml/acs", - ServiceProviderIssuer: baseURI.String() + "/saml", - SignAuthnRequests: true, - AudienceURI: baseURI.String() + "/saml", - IDPCertificateStore: &certStore, - SPKeyStore: keyStore, - AllowMissingAttributes: true, - } - - p := &SamlProvider{ - ServiceProvider: sp, - } - return p, nil -} - -func (p SamlProvider) AuthCodeURL(tokenString string, args ...oauth2.AuthCodeOption) string { - url, err := p.ServiceProvider.BuildAuthURL(tokenString) - if err != nil { - return "" - } - return url -} - -func (p SamlProvider) SPMetadata() ([]byte, error) { - metadata, err := p.ServiceProvider.Metadata() - if err != nil { - return nil, err - } - - // the typing for encryption methods currently causes the xml to violate the spec - // therefore they are removed since they are optional anyways and mostly unused - metadata.SPSSODescriptor.KeyDescriptors[1].EncryptionMethods = []types.EncryptionMethod{} - - rawMetadata, err := xml.Marshal(metadata) - if err != nil { - return nil, err - } - - return rawMetadata, nil -} - -func (ks ConfigX509KeyStore) GetKeyPair() (*rsa.PrivateKey, []byte, error) { - if ks.Conf.SigningCert == "" && ks.Conf.SigningKey == "" { - return ks.CreateSigningCert() - } - - keyPair, err := tls.X509KeyPair([]byte(ks.Conf.SigningCert), []byte(ks.Conf.SigningKey)) - if err != nil { - return nil, nil, fmt.Errorf("Parsing key pair failed: %+v", err) - } - - var privKey *rsa.PrivateKey - switch key := keyPair.PrivateKey.(type) { - case *rsa.PrivateKey: - privKey = key - default: - return nil, nil, errors.New("Private key is not an RSA key") - } - - return privKey, keyPair.Certificate[0], nil -} - -func (ks ConfigX509KeyStore) CreateSigningCert() (*rsa.PrivateKey, []byte, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - - currentTime := time.Now() - - certBody := &x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: currentTime.Add(-5 * time.Minute), - NotAfter: currentTime.Add(365 * 24 * time.Hour), - - KeyUsage: x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{}, - BasicConstraintsValid: true, - } - - cert, err := x509.CreateCertificate(rand.Reader, certBody, certBody, &key.PublicKey, key) - if err != nil { - return nil, nil, fmt.Errorf("Failed to create certificate: %+v", err) - } - - if err := ks.SaveConfig(cert, key); err != nil { - return nil, nil, fmt.Errorf("Saving signing keypair failed: %+v", err) - } - - return key, cert, nil -} - -func (ks ConfigX509KeyStore) SaveConfig(cert []byte, key *rsa.PrivateKey) error { - if ks.InstanceID == uuid.Nil { - return nil - } - - pemCert := &pem.Block{ - Type: "CERTIFICATE", - Bytes: cert, - } - - certBytes := pem.EncodeToMemory(pemCert) - if certBytes == nil { - return errors.New("Could not encode certificate") - } - - pemKey := &pem.Block{ - Type: "PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - } - - keyBytes := pem.EncodeToMemory(pemKey) - if keyBytes == nil { - return errors.New("Could not encode key") - } - - instance, err := models.GetInstance(ks.DB, ks.InstanceID) - if err != nil { - return err - } - - conf := instance.BaseConfig - conf.External.Saml.SigningCert = string(certBytes) - conf.External.Saml.SigningKey = string(keyBytes) - - if err := instance.UpdateConfig(ks.DB, conf); err != nil { - return err - } - - return nil -} diff --git a/api/recover.go b/api/recover.go deleted file mode 100644 index 05de55d25..000000000 --- a/api/recover.go +++ /dev/null @@ -1,59 +0,0 @@ -package api - -import ( - "encoding/json" - "errors" - "net/http" - - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" -) - -// RecoverParams holds the parameters for a password recovery request -type RecoverParams struct { - Email string `json:"email"` -} - -// Recover sends a recovery email -func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - instanceID := getInstanceID(ctx) - params := &RecoverParams{} - jsonDecoder := json.NewDecoder(r.Body) - err := jsonDecoder.Decode(params) - if err != nil { - return badRequestError("Could not read verification params: %v", err) - } - - if params.Email == "" { - return unprocessableEntityError("Password recovery requires an email") - } - - aud := a.requestAud(ctx, r) - user, err := models.FindUserByEmailAndAudience(a.db, instanceID, params.Email, aud) - if err != nil { - if models.IsNotFoundError(err) { - return notFoundError(err.Error()) - } - return internalServerError("Database error finding user").WithInternalError(err) - } - - err = a.db.Transaction(func(tx *storage.Connection) error { - if terr := models.NewAuditLogEntry(tx, instanceID, user, models.UserRecoveryRequestedAction, nil); terr != nil { - return terr - } - - mailer := a.Mailer(ctx) - referrer := a.getReferrer(r) - return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer) - }) - if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") - } - return internalServerError("Error recovering user").WithInternalError(err) - } - - return sendJSON(w, http.StatusOK, &map[string]string{}) -} diff --git a/api/settings.go b/api/settings.go deleted file mode 100644 index c53f751de..000000000 --- a/api/settings.go +++ /dev/null @@ -1,70 +0,0 @@ -package api - -import "net/http" - -type ProviderSettings struct { - Apple bool `json:"apple"` - Azure bool `json:"azure"` - Bitbucket bool `json:"bitbucket"` - Discord bool `json:"discord"` - GitHub bool `json:"github"` - GitLab bool `json:"gitlab"` - Google bool `json:"google"` - Linkedin bool `json:"linkedin"` - Facebook bool `json:"facebook"` - Notion bool `json:"notion"` - Spotify bool `json:"spotify"` - Slack bool `json:"slack"` - Twitch bool `json:"twitch"` - Twitter bool `json:"twitter"` - Email bool `json:"email"` - Phone bool `json:"phone"` - SAML bool `json:"saml"` -} - -type ProviderLabels struct { - SAML string `json:"saml,omitempty"` -} - -type Settings struct { - ExternalProviders ProviderSettings `json:"external"` - ExternalLabels ProviderLabels `json:"external_labels"` - DisableSignup bool `json:"disable_signup"` - MailerAutoconfirm bool `json:"mailer_autoconfirm"` - PhoneAutoconfirm bool `json:"phone_autoconfirm"` - SmsProvider string `json:"sms_provider"` -} - -func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { - config := a.getConfig(r.Context()) - - return sendJSON(w, http.StatusOK, &Settings{ - ExternalProviders: ProviderSettings{ - Apple: config.External.Apple.Enabled, - Azure: config.External.Azure.Enabled, - Bitbucket: config.External.Bitbucket.Enabled, - Discord: config.External.Discord.Enabled, - GitHub: config.External.Github.Enabled, - GitLab: config.External.Gitlab.Enabled, - Google: config.External.Google.Enabled, - Linkedin: config.External.Linkedin.Enabled, - Facebook: config.External.Facebook.Enabled, - Notion: config.External.Notion.Enabled, - Spotify: config.External.Spotify.Enabled, - Slack: config.External.Slack.Enabled, - Twitch: config.External.Twitch.Enabled, - Twitter: config.External.Twitter.Enabled, - Email: config.External.Email.Enabled, - Phone: config.External.Phone.Enabled, - SAML: config.External.Saml.Enabled, - }, - ExternalLabels: ProviderLabels{ - SAML: config.External.Saml.Name, - }, - - DisableSignup: config.DisableSignup, - MailerAutoconfirm: config.Mailer.Autoconfirm, - PhoneAutoconfirm: config.Sms.Autoconfirm, - SmsProvider: config.Sms.Provider, - }) -} diff --git a/api/signup.go b/api/signup.go deleted file mode 100644 index ae4ff92ad..000000000 --- a/api/signup.go +++ /dev/null @@ -1,309 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/metering" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" -) - -// SignupParams are the parameters the Signup endpoint accepts -type SignupParams struct { - Email string `json:"email"` - Phone string `json:"phone"` - Password string `json:"password"` - Data map[string]interface{} `json:"data"` - Provider string `json:"-"` - Aud string `json:"-"` -} - -// Signup is the endpoint for registering a new user -func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - - if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") - } - - params := &SignupParams{} - jsonDecoder := json.NewDecoder(r.Body) - err := jsonDecoder.Decode(params) - if err != nil { - return badRequestError("Could not read Signup params: %v", err) - } - - if params.Password == "" { - return unprocessableEntityError("Signup requires a valid password") - } - if len(params.Password) < config.PasswordMinLength { - return unprocessableEntityError(fmt.Sprintf("Password should be at least %d characters", config.PasswordMinLength)) - } - if params.Email != "" && params.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on signup.") - } - if params.Email != "" { - params.Provider = "email" - } else if params.Phone != "" { - params.Provider = "phone" - } - if params.Data == nil { - params.Data = make(map[string]interface{}) - } - - var user *models.User - instanceID := getInstanceID(ctx) - params.Aud = a.requestAud(ctx, r) - - switch params.Provider { - case "email": - if !config.External.Email.Enabled { - return badRequestError("Email signups are disabled") - } - if err := a.validateEmail(ctx, params.Email); err != nil { - return err - } - user, err = models.FindUserByEmailAndAudience(a.db, instanceID, params.Email, params.Aud) - case "phone": - if !config.External.Phone.Enabled { - return badRequestError("Phone signups are disabled") - } - params.Phone = a.formatPhoneNumber(params.Phone) - if isValid := a.validateE164Format(params.Phone); !isValid { - return unprocessableEntityError("Invalid phone number format") - } - user, err = models.FindUserByPhoneAndAudience(a.db, instanceID, params.Phone, params.Aud) - default: - return invalidSignupError(config) - } - - if err != nil && !models.IsNotFoundError(err) { - return internalServerError("Database error finding user").WithInternalError(err) - } - - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - if user != nil { - if (params.Provider == "email" && user.IsConfirmed()) || (params.Provider == "phone" && user.IsPhoneConfirmed()) { - return UserExistsError - } - - if err := user.UpdateUserMetaData(tx, params.Data); err != nil { - return internalServerError("Database error updating user").WithInternalError(err) - } - } else { - user, terr = a.signupNewUser(ctx, tx, params) - if terr != nil { - return terr - } - identity, terr := a.createNewIdentity(tx, user, params.Provider, map[string]interface{}{"sub": user.ID.String()}) - if terr != nil { - return terr - } - user.Identities = []models.Identity{*identity} - } - - if params.Provider == "email" && !user.IsConfirmed() { - if config.Mailer.Autoconfirm { - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserSignedUpAction, nil); terr != nil { - return terr - } - if terr = triggerEventHooks(ctx, tx, SignupEvent, user, instanceID, config); terr != nil { - return terr - } - if terr = user.Confirm(tx); terr != nil { - return internalServerError("Database error updating user").WithInternalError(terr) - } - } else { - mailer := a.Mailer(ctx) - referrer := a.getReferrer(r) - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserConfirmationRequestedAction, nil); terr != nil { - return terr - } - if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer); terr != nil { - if errors.Is(terr, MaxFrequencyLimitError) { - now := time.Now() - left := user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency).Sub(now) / time.Second - return tooManyRequestsError(fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) - } - return internalServerError("Error sending confirmation mail").WithInternalError(terr) - } - } - } else if params.Provider == "phone" && !user.IsPhoneConfirmed() { - if config.Sms.Autoconfirm { - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserSignedUpAction, nil); terr != nil { - return terr - } - if terr = triggerEventHooks(ctx, tx, SignupEvent, user, instanceID, config); terr != nil { - return terr - } - if terr = user.ConfirmPhone(tx); terr != nil { - return internalServerError("Database error updating user").WithInternalError(terr) - } - } else { - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserConfirmationRequestedAction, nil); terr != nil { - return terr - } - if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone); terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) - } - } - } - - return nil - }) - - if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every minute") - } - if errors.Is(err, UserExistsError) { - err = a.db.Transaction(func(tx *storage.Connection) error { - if terr := models.NewAuditLogEntry(tx, instanceID, user, models.UserRepeatedSignUpAction, nil); terr != nil { - return terr - } - return nil - }) - if err != nil { - return err - } - if config.Mailer.Autoconfirm || config.Sms.Autoconfirm { - return badRequestError("User already registered") - } - sanitizedUser, err := sanitizeUser(user, params) - if err != nil { - return err - } - return sendJSON(w, http.StatusOK, sanitizedUser) - } - return err - } - - // handles case where Mailer.Autoconfirm is true or Phone.Autoconfirm is true - if user.IsConfirmed() || user.IsPhoneConfirmed() { - var token *AccessTokenResponse - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.LoginAction, nil); terr != nil { - return terr - } - if terr = triggerEventHooks(ctx, tx, LoginEvent, user, instanceID, config); terr != nil { - return terr - } - - token, terr = a.issueRefreshToken(ctx, tx, user) - if terr != nil { - return terr - } - - if terr = a.setCookieTokens(config, token, false, w); terr != nil { - return internalServerError("Failed to set JWT cookie. %s", terr) - } - return nil - }) - if err != nil { - return err - } - metering.RecordLogin("password", user.ID, instanceID) - token.User = user - return sendJSON(w, http.StatusOK, token) - } - - return sendJSON(w, http.StatusOK, user) -} - -// sanitizeUser removes all user sensitive information from the user object -// Should be used whenever we want to prevent information about whether a user is registered or not from leaking -func sanitizeUser(u *models.User, params *SignupParams) (*models.User, error) { - var err error - now := time.Now() - - u.ID, err = uuid.NewV4() - if err != nil { - return nil, errors.Wrap(err, "Error generating unique id") - } - u.CreatedAt, u.UpdatedAt, u.ConfirmationSentAt = now, now, &now - u.LastSignInAt, u.ConfirmedAt, u.EmailConfirmedAt, u.PhoneConfirmedAt = nil, nil, nil, nil - u.Identities = make([]models.Identity, 0) - u.UserMetaData = params.Data - u.Aud = params.Aud - - // sanitize app_metadata - u.AppMetaData = map[string]interface{}{ - "provider": params.Provider, - "providers": []string{params.Provider}, - } - - // sanitize param fields - switch params.Provider { - case "email": - u.Phone = "" - case "phone": - u.Email = "" - default: - u.Phone, u.Email = "", "" - } - - return u, nil -} - -func (a *API) signupNewUser(ctx context.Context, conn *storage.Connection, params *SignupParams) (*models.User, error) { - instanceID := getInstanceID(ctx) - config := a.getConfig(ctx) - - var user *models.User - var err error - switch params.Provider { - case "email": - user, err = models.NewUser(instanceID, params.Email, params.Password, params.Aud, params.Data) - case "phone": - user, err = models.NewUser(instanceID, "", params.Password, params.Aud, params.Data) - user.Phone = storage.NullString(params.Phone) - default: - // handles external provider case - user, err = models.NewUser(instanceID, params.Email, params.Password, params.Aud, params.Data) - } - - if err != nil { - return nil, internalServerError("Database error creating user").WithInternalError(err) - } - if user.AppMetaData == nil { - user.AppMetaData = make(map[string]interface{}) - } - - user.Identities = make([]models.Identity, 0) - - // TODO: Deprecate "provider" field - user.AppMetaData["provider"] = params.Provider - - user.AppMetaData["providers"] = []string{params.Provider} - if params.Password == "" { - user.EncryptedPassword = "" - } - - err = conn.Transaction(func(tx *storage.Connection) error { - var terr error - if terr = tx.Create(user); terr != nil { - return internalServerError("Database error saving new user").WithInternalError(terr) - } - if terr = user.SetRole(tx, config.JWT.DefaultGroupName); terr != nil { - return internalServerError("Database error updating user").WithInternalError(terr) - } - if terr = triggerEventHooks(ctx, tx, ValidateEvent, user, instanceID, config); terr != nil { - return terr - } - return nil - }) - if err != nil { - return nil, err - } - - return user, nil -} diff --git a/api/signup_test.go b/api/signup_test.go deleted file mode 100644 index f06756bc9..000000000 --- a/api/signup_test.go +++ /dev/null @@ -1,264 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type SignupTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - - instanceID uuid.UUID -} - -func TestSignup(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &SignupTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *SignupTestSuite) SetupTest() { - models.TruncateAll(ts.API.db) - ts.Config.Webhook = conf.WebhookConfig{} -} - -// TestSignup tests API /signup route -func (ts *SignupTestSuite) TestSignup() { - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - "password": "test123", - "data": map[string]interface{}{ - "a": 1, - }, - })) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "/signup", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - - require.Equal(ts.T(), http.StatusOK, w.Code) - - data := models.User{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - assert.Equal(ts.T(), "test@example.com", data.GetEmail()) - assert.Equal(ts.T(), ts.Config.JWT.Aud, data.Aud) - assert.Equal(ts.T(), 1.0, data.UserMetaData["a"]) - assert.Equal(ts.T(), "email", data.AppMetaData["provider"]) - assert.Equal(ts.T(), []interface{}{"email"}, data.AppMetaData["providers"]) -} - -func (ts *SignupTestSuite) TestWebhookTriggered() { - var callCount int - require := ts.Require() - assert := ts.Assert() - - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - assert.Equal("application/json", r.Header.Get("Content-Type")) - - // verify the signature - signature := r.Header.Get("x-webhook-signature") - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} - claims := new(jwt.StandardClaims) - token, err := p.ParseWithClaims(signature, claims, func(token *jwt.Token) (interface{}, error) { - return []byte(ts.Config.Webhook.Secret), nil - }) - assert.True(token.Valid) - assert.Equal(ts.instanceID.String(), claims.Subject) // not configured for multitenancy - assert.Equal("gotrue", claims.Issuer) - assert.WithinDuration(time.Now(), time.Unix(claims.IssuedAt, 0), 5*time.Second) - - // verify the contents - - defer squash(r.Body.Close) - raw, err := ioutil.ReadAll(r.Body) - require.NoError(err) - data := map[string]interface{}{} - require.NoError(json.Unmarshal(raw, &data)) - - assert.Equal(3, len(data)) - assert.Equal("validate", data["event"]) - assert.Equal(ts.instanceID.String(), data["instance_id"]) - - u, ok := data["user"].(map[string]interface{}) - require.True(ok) - assert.Len(u, 10) - // assert.Equal(t, user.ID, u["id"]) TODO - assert.Equal("authenticated", u["aud"]) - assert.Equal("authenticated", u["role"]) - assert.Equal("test@example.com", u["email"]) - - appmeta, ok := u["app_metadata"].(map[string]interface{}) - require.True(ok) - assert.Len(appmeta, 2) - assert.EqualValues("email", appmeta["provider"]) - assert.EqualValues([]interface{}{"email"}, appmeta["providers"]) - - usermeta, ok := u["user_metadata"].(map[string]interface{}) - require.True(ok) - assert.Len(usermeta, 1) - assert.EqualValues(1, usermeta["a"]) - })) - defer svr.Close() - - // Allowing connection to localhost for the tests only - localhost := removeLocalhostFromPrivateIPBlock() - defer unshiftPrivateIPBlock(localhost) - - ts.Config.Webhook = conf.WebhookConfig{ - URL: svr.URL, - Retries: 1, - TimeoutSec: 1, - Secret: "top-secret", - Events: []string{"validate"}, - } - var buffer bytes.Buffer - require.NoError(json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - "password": "test123", - "data": map[string]interface{}{ - "a": 1, - }, - })) - req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(http.StatusOK, w.Code) - assert.Equal(1, callCount) -} - -func (ts *SignupTestSuite) TestFailingWebhook() { - ts.Config.Webhook = conf.WebhookConfig{ - URL: "http://notaplace.localhost", - Retries: 1, - TimeoutSec: 1, - Events: []string{"validate", "signup"}, - } - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - "password": "test123", - "data": map[string]interface{}{ - "a": 1, - }, - })) - req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - - require.Equal(ts.T(), http.StatusBadGateway, w.Code) -} - -// TestSignupTwice checks to make sure the same email cannot be registered twice -func (ts *SignupTestSuite) TestSignupTwice() { - // Request body - var buffer bytes.Buffer - - encode := func() { - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test1@example.com", - "password": "test123", - "data": map[string]interface{}{ - "a": 1, - }, - })) - } - - encode() - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - y := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(y, req) - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test1@example.com", ts.Config.JWT.Aud) - if err == nil { - require.NoError(ts.T(), u.Confirm(ts.API.db)) - } - - encode() - ts.API.handler.ServeHTTP(w, req) - - data := models.User{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - - require.Equal(ts.T(), http.StatusOK, w.Code) - - assert.NotEqual(ts.T(), u.ID, data.ID) - assert.Equal(ts.T(), "test1@example.com", data.GetEmail()) - assert.Equal(ts.T(), ts.Config.JWT.Aud, data.Aud) - assert.Equal(ts.T(), 1.0, data.UserMetaData["a"]) - assert.Equal(ts.T(), "email", data.AppMetaData["provider"]) - assert.Equal(ts.T(), []interface{}{"email"}, data.AppMetaData["providers"]) -} - -func (ts *SignupTestSuite) TestVerifySignup() { - user, err := models.NewUser(ts.instanceID, "test@example.com", "testing", ts.Config.JWT.Aud, nil) - user.ConfirmationToken = "asdf3" - now := time.Now() - user.ConfirmationSentAt = &now - require.NoError(ts.T(), err) - require.NoError(ts.T(), ts.API.db.Create(user)) - - // Find test user - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "type": "signup", - "token": u.ConfirmationToken, - })) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - - assert.Equal(ts.T(), http.StatusOK, w.Code, w.Body.String()) -} diff --git a/api/sms_provider/sms_provider.go b/api/sms_provider/sms_provider.go deleted file mode 100644 index 11b3b3f2b..000000000 --- a/api/sms_provider/sms_provider.go +++ /dev/null @@ -1,26 +0,0 @@ -package sms_provider - -import ( - "fmt" - - "github.com/netlify/gotrue/conf" -) - -type SmsProvider interface { - SendSms(phone, message string) error -} - -func GetSmsProvider(config conf.Configuration) (SmsProvider, error) { - switch name := config.Sms.Provider; name { - case "twilio": - return NewTwilioProvider(config.Sms.Twilio) - case "messagebird": - return NewMessagebirdProvider(config.Sms.Messagebird) - case "textlocal": - return NewTextlocalProvider(config.Sms.Textlocal) - case "vonage": - return NewVonageProvider(config.Sms.Vonage) - default: - return nil, fmt.Errorf("Sms Provider %s could not be found", name) - } -} diff --git a/api/sms_provider/textlocal.go b/api/sms_provider/textlocal.go deleted file mode 100644 index c45cab703..000000000 --- a/api/sms_provider/textlocal.go +++ /dev/null @@ -1,83 +0,0 @@ -package sms_provider - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "strings" - - "github.com/netlify/gotrue/conf" -) - -const ( - defaultTextLocalApiBase = "https://api.textlocal.in" -) - -type TextlocalProvider struct { - Config *conf.TextlocalProviderConfiguration - APIPath string -} - -type TextlocalError struct { - Code int `json:"code"` - Message string `json:"message"` -} - -type TextlocalResponse struct { - Status string `json:"status"` - Errors []TextlocalError `json:"errors"` -} - -// Creates a SmsProvider with the Textlocal Config -func NewTextlocalProvider(config conf.TextlocalProviderConfiguration) (SmsProvider, error) { - if err := config.Validate(); err != nil { - return nil, err - } - - apiPath := defaultTextLocalApiBase + "/send" - return &TextlocalProvider{ - Config: &config, - APIPath: apiPath, - }, nil -} - -// Send an SMS containing the OTP with Textlocal's API -func (t TextlocalProvider) SendSms(phone string, message string) error { - body := url.Values{ - "sender": {t.Config.Sender}, - "apikey": {t.Config.ApiKey}, - "message": {message}, - "numbers": {phone}, - } - - client := &http.Client{} - r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) - if err != nil { - return err - } - - r.Header.Add("Content-Type", "application/x-www-form-urlencoded") - res, err := client.Do(r) - if err != nil { - return err - } - defer res.Body.Close() - - resp := &TextlocalResponse{} - derr := json.NewDecoder(res.Body).Decode(resp) - if derr != nil { - return derr - } - - if len(resp.Errors) == 0 { - return errors.New("Textlocal error: Internal Error") - } - - if resp.Status != "success" { - return fmt.Errorf("Textlocal error: %v (code: %v)", resp.Errors[0].Message, resp.Errors[0].Code) - } - - return nil -} diff --git a/api/sms_provider/twilio.go b/api/sms_provider/twilio.go deleted file mode 100644 index c9f4cc6fe..000000000 --- a/api/sms_provider/twilio.go +++ /dev/null @@ -1,97 +0,0 @@ -package sms_provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - - "github.com/netlify/gotrue/conf" -) - -const ( - defaultTwilioApiBase = "https://api.twilio.com" - apiVersion = "2010-04-01" -) - -type TwilioProvider struct { - Config *conf.TwilioProviderConfiguration - APIPath string -} - -type SmsStatus struct { - To string `json:"to"` - From string `json:"from"` - Status string `json:"status"` - ErrorCode string `json:"error_code"` - ErrorMessage string `json:"error_message"` - Body string `json:"body"` -} - -type twilioErrResponse struct { - Code int `json:"code"` - Message string `json:"message"` - MoreInfo string `json:"more_info"` - Status int `json:"status"` -} - -func (t twilioErrResponse) Error() string { - return fmt.Sprintf("%s More information: %s", t.Message, t.MoreInfo) -} - -// Creates a SmsProvider with the Twilio Config -func NewTwilioProvider(config conf.TwilioProviderConfiguration) (SmsProvider, error) { - if err := config.Validate(); err != nil { - return nil, err - } - - apiPath := defaultTwilioApiBase + "/" + apiVersion + "/" + "Accounts" + "/" + config.AccountSid + "/Messages.json" - return &TwilioProvider{ - Config: &config, - APIPath: apiPath, - }, nil -} - -// Send an SMS containing the OTP with Twilio's API -func (t TwilioProvider) SendSms(phone string, message string) error { - body := url.Values{ - "To": {"+" + phone}, // twilio api requires "+" extension to be included - "Channel": {"sms"}, - "From": {t.Config.MessageServiceSid}, - "Body": {message}, - } - - client := &http.Client{} - r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) - if err != nil { - return err - } - r.Header.Add("Content-Type", "application/x-www-form-urlencoded") - r.SetBasicAuth(t.Config.AccountSid, t.Config.AuthToken) - res, err := client.Do(r) - if err != nil { - return err - } - if res.StatusCode == http.StatusBadRequest || res.StatusCode == http.StatusForbidden { - resp := &twilioErrResponse{} - if err := json.NewDecoder(res.Body).Decode(resp); err != nil { - return err - } - return resp - } - defer res.Body.Close() - - // validate sms status - resp := &SmsStatus{} - derr := json.NewDecoder(res.Body).Decode(resp) - if derr != nil { - return derr - } - - if resp.Status == "failed" || resp.Status == "undelivered" { - return fmt.Errorf("Twilio error: %v %v", resp.ErrorMessage, resp.ErrorCode) - } - - return nil -} diff --git a/api/testdata/saml-idp-metadata.xml b/api/testdata/saml-idp-metadata.xml deleted file mode 100644 index 5ee3bd12b..000000000 --- a/api/testdata/saml-idp-metadata.xml +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - {{.Cert}} - - - - urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress - - - - diff --git a/api/testdata/saml-response.xml b/api/testdata/saml-response.xml deleted file mode 100644 index dffe20ea5..000000000 --- a/api/testdata/saml-response.xml +++ /dev/null @@ -1,26 +0,0 @@ - - - https://idp/saml2test - - - - - https://idp/saml2test - - saml@example.com - - - - - - - http://localhost/saml - - - - - urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified - - - - \ No newline at end of file diff --git a/api/token.go b/api/token.go deleted file mode 100644 index 8821bb8c1..000000000 --- a/api/token.go +++ /dev/null @@ -1,578 +0,0 @@ -package api - -import ( - "context" - "crypto/sha256" - "encoding/json" - "errors" - "fmt" - "net/http" - "strconv" - "time" - - "github.com/coreos/go-oidc/v3/oidc" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/metering" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" -) - -// GoTrueClaims is a struct thats used for JWT claims -type GoTrueClaims struct { - jwt.StandardClaims - Email string `json:"email"` - Phone string `json:"phone"` - AppMetaData map[string]interface{} `json:"app_metadata"` - UserMetaData map[string]interface{} `json:"user_metadata"` - Role string `json:"role"` -} - -// AccessTokenResponse represents an OAuth2 success response -type AccessTokenResponse struct { - Token string `json:"access_token"` - TokenType string `json:"token_type"` // Bearer - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token"` - User *models.User `json:"user"` -} - -// PasswordGrantParams are the parameters the ResourceOwnerPasswordGrant method accepts -type PasswordGrantParams struct { - Email string `json:"email"` - Phone string `json:"phone"` - Password string `json:"password"` -} - -// RefreshTokenGrantParams are the parameters the RefreshTokenGrant method accepts -type RefreshTokenGrantParams struct { - RefreshToken string `json:"refresh_token"` -} - -// IdTokenGrantParams are the parameters the IdTokenGrant method accepts -type IdTokenGrantParams struct { - IdToken string `json:"id_token"` - Nonce string `json:"nonce"` - Provider string `json:"provider"` -} - -const useCookieHeader = "x-use-cookie" -const useSessionCookie = "session" -const InvalidLoginMessage = "Invalid login credentials" - -func (p *IdTokenGrantParams) getVerifier(ctx context.Context) (*oidc.IDTokenVerifier, error) { - config := getConfig(ctx) - - var provider *oidc.Provider - var err error - var oAuthProvider conf.OAuthProviderConfiguration - var oAuthProviderClientId string - switch p.Provider { - case "apple": - oAuthProvider = config.External.Apple - oAuthProviderClientId = config.External.IosBundleId - provider, err = oidc.NewProvider(ctx, "https://appleid.apple.com") - case "azure": - oAuthProvider = config.External.Azure - oAuthProviderClientId = oAuthProvider.ClientID - provider, err = oidc.NewProvider(ctx, "https://login.microsoftonline.com/common/v2.0") - case "facebook": - oAuthProvider = config.External.Facebook - oAuthProviderClientId = oAuthProvider.ClientID - provider, err = oidc.NewProvider(ctx, "https://www.facebook.com") - case "google": - oAuthProvider = config.External.Google - oAuthProviderClientId = oAuthProvider.ClientID - provider, err = oidc.NewProvider(ctx, "https://accounts.google.com") - default: - return nil, fmt.Errorf("Provider %s doesn't support the id_token grant flow", p.Provider) - } - - if err != nil { - return nil, err - } - - if !oAuthProvider.Enabled { - return nil, badRequestError("Provider is not enabled") - } - - return provider.Verifier(&oidc.Config{ClientID: oAuthProviderClientId}), nil -} - -func getEmailVerified(v interface{}) bool { - var emailVerified bool - var err error - switch v.(type) { - case string: - emailVerified, err = strconv.ParseBool(v.(string)) - case bool: - emailVerified = v.(bool) - default: - emailVerified = false - } - if err != nil { - return false - } - return emailVerified -} - -// Token is the endpoint for OAuth access token requests -func (a *API) Token(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - grantType := r.FormValue("grant_type") - - switch grantType { - case "password": - return a.ResourceOwnerPasswordGrant(ctx, w, r) - case "refresh_token": - return a.RefreshTokenGrant(ctx, w, r) - case "id_token": - return a.IdTokenGrant(ctx, w, r) - default: - return oauthError("unsupported_grant_type", "") - } -} - -// ResourceOwnerPasswordGrant implements the password grant type flow -func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - params := &PasswordGrantParams{} - - jsonDecoder := json.NewDecoder(r.Body) - if err := jsonDecoder.Decode(params); err != nil { - return badRequestError("Could not read password grant params: %v", err) - } - - aud := a.requestAud(ctx, r) - instanceID := getInstanceID(ctx) - config := a.getConfig(ctx) - - if params.Email != "" && params.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on login.") - } - var user *models.User - var err error - if params.Email != "" { - if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") - } - user, err = models.FindUserByEmailAndAudience(a.db, instanceID, params.Email, aud) - } else if params.Phone != "" { - if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") - } - params.Phone = a.formatPhoneNumber(params.Phone) - user, err = models.FindUserByPhoneAndAudience(a.db, instanceID, params.Phone, aud) - } else { - return oauthError("invalid_grant", InvalidLoginMessage) - } - - if err != nil { - if models.IsNotFoundError(err) { - return oauthError("invalid_grant", InvalidLoginMessage) - } - return internalServerError("Database error querying schema").WithInternalError(err) - } - - if user.IsBanned() || !user.Authenticate(params.Password) { - return oauthError("invalid_grant", InvalidLoginMessage) - } - - if params.Email != "" && !user.IsConfirmed() { - return oauthError("invalid_grant", "Email not confirmed") - } else if params.Phone != "" && !user.IsPhoneConfirmed() { - return oauthError("invalid_grant", "Phone not confirmed") - } - - var token *AccessTokenResponse - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.LoginAction, nil); terr != nil { - return terr - } - if terr = triggerEventHooks(ctx, tx, LoginEvent, user, instanceID, config); terr != nil { - return terr - } - - token, terr = a.issueRefreshToken(ctx, tx, user) - if terr != nil { - return terr - } - - if terr = a.setCookieTokens(config, token, false, w); terr != nil { - return internalServerError("Failed to set JWT cookie. %s", terr) - } - return nil - }) - if err != nil { - return err - } - metering.RecordLogin("password", user.ID, instanceID) - token.User = user - return sendJSON(w, http.StatusOK, token) -} - -// RefreshTokenGrant implements the refresh_token grant type flow -func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - config := a.getConfig(ctx) - instanceID := getInstanceID(ctx) - - params := &RefreshTokenGrantParams{} - - jsonDecoder := json.NewDecoder(r.Body) - if err := jsonDecoder.Decode(params); err != nil { - return badRequestError("Could not read refresh token grant params: %v", err) - } - - if params.RefreshToken == "" { - return oauthError("invalid_request", "refresh_token required") - } - - user, token, err := models.FindUserWithRefreshToken(a.db, params.RefreshToken) - if err != nil { - if models.IsNotFoundError(err) { - return oauthError("invalid_grant", "Invalid Refresh Token") - } - return internalServerError(err.Error()) - } - - if user.IsBanned() { - return oauthError("invalid_grant", "Invalid Refresh Token") - } - - if !(config.External.Email.Enabled && config.External.Phone.Enabled) { - providers, err := models.FindProvidersByUser(a.db, user) - if err != nil { - return internalServerError(err.Error()) - } - for _, provider := range providers { - if provider == "email" && !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") - } - if provider == "phone" && !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") - } - } - } - - if token.Revoked { - a.clearCookieTokens(config, w) - if config.Security.RefreshTokenRotationEnabled { - // Revoke all tokens in token family - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - if terr = models.RevokeTokenFamily(tx, token); terr != nil { - return terr - } - return nil - }) - if err != nil { - return internalServerError(err.Error()) - } - } - return oauthError("invalid_grant", "Invalid Refresh Token").WithInternalMessage("Possible abuse attempt: %v", r) - } - - var tokenString string - var newToken *models.RefreshToken - var newTokenResponse *AccessTokenResponse - - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.TokenRefreshedAction, nil); terr != nil { - return terr - } - - newToken, terr = models.GrantRefreshTokenSwap(tx, user, token) - if terr != nil { - return internalServerError(terr.Error()) - } - - tokenString, terr = generateAccessToken(user, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret) - if terr != nil { - return internalServerError("error generating jwt token").WithInternalError(terr) - } - - newTokenResponse = &AccessTokenResponse{ - Token: tokenString, - TokenType: "bearer", - ExpiresIn: config.JWT.Exp, - RefreshToken: newToken.Token, - User: user, - } - if terr = a.setCookieTokens(config, newTokenResponse, false, w); terr != nil { - return internalServerError("Failed to set JWT cookie. %s", terr) - } - - return nil - }) - if err != nil { - return err - } - metering.RecordLogin("token", user.ID, instanceID) - return sendJSON(w, http.StatusOK, newTokenResponse) -} - -// IdTokenGrant implements the id_token grant type flow -func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - config := a.getConfig(ctx) - instanceID := getInstanceID(ctx) - - params := &IdTokenGrantParams{} - - jsonDecoder := json.NewDecoder(r.Body) - if err := jsonDecoder.Decode(params); err != nil { - return badRequestError("Could not read id token grant params: %v", err) - } - - if params.IdToken == "" || params.Nonce == "" || params.Provider == "" { - return oauthError("invalid request", "id_token, nonce and provider required") - } - - verifier, err := params.getVerifier(ctx) - if err != nil { - return err - } - - idToken, err := verifier.Verify(ctx, params.IdToken) - if err != nil { - return badRequestError("%v", err) - } - - claims := make(map[string]interface{}) - if err := idToken.Claims(&claims); err != nil { - return err - } - - // verify nonce to mitigate replay attacks - hashedNonce, ok := claims["nonce"] - if !ok { - return oauthError("invalid request", "missing nonce in id_token") - } - hash := fmt.Sprintf("%x", sha256.Sum256([]byte(params.Nonce))) - if hash != hashedNonce.(string) { - return oauthError("invalid nonce", "").WithInternalMessage("Possible abuse attempt: %v", r) - } - - sub, ok := claims["sub"].(string) - if !ok { - return oauthError("invalid request", "missing sub claim in id_token") - } - - email, ok := claims["email"].(string) - if !ok { - email = "" - } - - var user *models.User - var token *AccessTokenResponse - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - var identity *models.Identity - - if identity, terr = models.FindIdentityByIdAndProvider(tx, sub, params.Provider); terr != nil { - // create new identity & user if identity is not found - if models.IsNotFoundError(terr) { - if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") - } - aud := a.requestAud(ctx, r) - signupParams := &SignupParams{ - Provider: params.Provider, - Email: email, - Aud: aud, - Data: claims, - } - - user, terr = a.signupNewUser(ctx, tx, signupParams) - if terr != nil { - return terr - } - if identity, terr = a.createNewIdentity(tx, user, params.Provider, claims); terr != nil { - return terr - } - } else { - return terr - } - } else { - user, terr = models.FindUserByID(tx, identity.UserID) - if terr != nil { - return terr - } - if email != "" { - identity.IdentityData["email"] = email - } - if user.IsBanned() { - return oauthError("invalid_grant", "invalid id token grant") - } - if terr = tx.UpdateOnly(identity, "identity_data", "last_sign_in_at"); terr != nil { - return terr - } - if terr = user.UpdateAppMetaDataProviders(tx); terr != nil { - return terr - } - } - - if !user.IsConfirmed() { - isEmailVerified := false - emailVerified, ok := claims["email_verified"] - if ok { - isEmailVerified = getEmailVerified(emailVerified) - } - if (!ok || !isEmailVerified) && !config.Mailer.Autoconfirm { - mailer := a.Mailer(ctx) - referrer := a.getReferrer(r) - if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer); terr != nil { - return internalServerError("Error sending confirmation mail").WithInternalError(terr) - } - return unauthorizedError("Error unverified email") - } - - if terr := models.NewAuditLogEntry(tx, instanceID, user, models.UserSignedUpAction, nil); terr != nil { - return terr - } - - if terr = triggerEventHooks(ctx, tx, SignupEvent, user, instanceID, config); terr != nil { - return terr - } - - if terr = user.Confirm(tx); terr != nil { - return internalServerError("Error updating user").WithInternalError(terr) - } - } else { - if terr := models.NewAuditLogEntry(tx, instanceID, user, models.LoginAction, nil); terr != nil { - return terr - } - if terr = triggerEventHooks(ctx, tx, LoginEvent, user, instanceID, config); terr != nil { - return terr - } - } - - token, terr = a.issueRefreshToken(ctx, tx, user) - if terr != nil { - return oauthError("server_error", terr.Error()) - } - return nil - }) - - if err != nil { - return err - } - - if err := a.setCookieTokens(config, token, false, w); err != nil { - return internalServerError("Failed to set JWT cookie. %s", err) - } - - metering.RecordLogin("id_token", user.ID, instanceID) - return sendJSON(w, http.StatusOK, &AccessTokenResponse{ - Token: token.Token, - TokenType: token.TokenType, - ExpiresIn: token.ExpiresIn, - RefreshToken: token.RefreshToken, - User: user, - }) -} - -func generateAccessToken(user *models.User, expiresIn time.Duration, secret string) (string, error) { - claims := &GoTrueClaims{ - StandardClaims: jwt.StandardClaims{ - Subject: user.ID.String(), - Audience: user.Aud, - ExpiresAt: time.Now().Add(expiresIn).Unix(), - }, - Email: user.GetEmail(), - Phone: user.GetPhone(), - AppMetaData: user.AppMetaData, - UserMetaData: user.UserMetaData, - Role: user.Role, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(secret)) -} - -func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User) (*AccessTokenResponse, error) { - config := a.getConfig(ctx) - - now := time.Now() - user.LastSignInAt = &now - - var tokenString string - var refreshToken *models.RefreshToken - - err := conn.Transaction(func(tx *storage.Connection) error { - var terr error - refreshToken, terr = models.GrantAuthenticatedUser(tx, user) - if terr != nil { - return internalServerError("Database error granting user").WithInternalError(terr) - } - - tokenString, terr = generateAccessToken(user, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret) - if terr != nil { - return internalServerError("error generating jwt token").WithInternalError(terr) - } - return nil - }) - if err != nil { - return nil, err - } - - return &AccessTokenResponse{ - Token: tokenString, - TokenType: "bearer", - ExpiresIn: config.JWT.Exp, - RefreshToken: refreshToken.Token, - }, nil -} - -// setCookieTokens sets the access_token & refresh_token in the cookies -func (a *API) setCookieTokens(config *conf.Configuration, token *AccessTokenResponse, session bool, w http.ResponseWriter) error { - // don't need to catch error here since we always set the cookie name - _ = a.setCookieToken(config, "access-token", token.Token, session, w) - _ = a.setCookieToken(config, "refresh-token", token.RefreshToken, session, w) - return nil -} - -func (a *API) setCookieToken(config *conf.Configuration, name string, tokenString string, session bool, w http.ResponseWriter) error { - if name == "" { - return errors.New("Failed to set cookie, invalid name") - } - cookieName := config.Cookie.Key + "-" + name - exp := time.Second * time.Duration(config.Cookie.Duration) - cookie := &http.Cookie{ - Name: cookieName, - Value: tokenString, - Secure: true, - HttpOnly: true, - Path: "/", - Domain: config.Cookie.Domain, - } - if !session { - cookie.Expires = time.Now().Add(exp) - cookie.MaxAge = config.Cookie.Duration - } - - http.SetCookie(w, cookie) - return nil -} - -func (a *API) clearCookieTokens(config *conf.Configuration, w http.ResponseWriter) { - a.clearCookieToken(config, "access-token", w) - a.clearCookieToken(config, "refresh-token", w) -} - -func (a *API) clearCookieToken(config *conf.Configuration, name string, w http.ResponseWriter) { - cookieName := config.Cookie.Key - if name != "" { - cookieName += "-" + name - } - http.SetCookie(w, &http.Cookie{ - Name: cookieName, - Value: "", - Expires: time.Now().Add(-1 * time.Hour * 10), - MaxAge: -1, - Secure: true, - HttpOnly: true, - Path: "/", - Domain: config.Cookie.Domain, - }) -} diff --git a/api/token_test.go b/api/token_test.go deleted file mode 100644 index eaa5dcb5f..000000000 --- a/api/token_test.go +++ /dev/null @@ -1,166 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "testing" - "time" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type TokenTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - - RefreshToken *models.RefreshToken - instanceID uuid.UUID -} - -func TestToken(t *testing.T) { - os.Setenv("GOTRUE_RATE_LIMIT_HEADER", "My-Custom-Header") - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &TokenTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *TokenTestSuite) SetupTest() { - ts.RefreshToken = nil - models.TruncateAll(ts.API.db) - - // Create user & refresh token - u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error creating test user model") - t := time.Now() - u.EmailConfirmedAt = &t - u.BannedUntil = nil - require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") - - ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u) - require.NoError(ts.T(), err, "Error creating refresh token") -} - -func (ts *TokenTestSuite) TestRateLimitToken() { - var buffer bytes.Buffer - req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("My-Custom-Header", "1.2.3.4") - - // It rate limits after 30 requests - for i := 0; i < 30; i++ { - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusBadRequest, w.Code) - } - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) - - // It ignores X-Forwarded-For by default - req.Header.Set("X-Forwarded-For", "1.1.1.1") - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) - - // It doesn't rate limit a new value for the limited header - req = httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("My-Custom-Header", "5.6.7.8") - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusBadRequest, w.Code) -} - -func (ts *TokenTestSuite) TestTokenPasswordGrantSuccess() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - "password": "password", - })) - - req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) -} - -func (ts *TokenTestSuite) TestTokenRefreshTokenGrantSuccess() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "refresh_token": ts.RefreshToken.Token, - })) - - req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) -} - -func (ts *TokenTestSuite) TestTokenPasswordGrantFailure() { - u := ts.createBannedUser() - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": u.GetEmail(), - "password": "password", - })) - - req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusBadRequest, w.Code) -} - -func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() { - _ = ts.createBannedUser() - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "refresh_token": ts.RefreshToken.Token, - })) - - req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusBadRequest, w.Code) -} - -func (ts *TokenTestSuite) createBannedUser() *models.User { - u, err := models.NewUser(ts.instanceID, "banned@example.com", "password", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error creating test user model") - t := time.Now() - u.EmailConfirmedAt = &t - t = t.Add(24 * time.Hour) - u.BannedUntil = &t - require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test banned user") - - ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u) - require.NoError(ts.T(), err, "Error creating refresh token") - - return u -} diff --git a/api/tracer.go b/api/tracer.go deleted file mode 100644 index 64652b8a1..000000000 --- a/api/tracer.go +++ /dev/null @@ -1,65 +0,0 @@ -package api - -import ( - "fmt" - "net/http" - "strconv" - - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/ext" - ddtrace_ext "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer" -) - -type tracingResponseWriter struct { - http.ResponseWriter - statusCode int -} - -func newTracingResponseWriter(w http.ResponseWriter) *tracingResponseWriter { - return &tracingResponseWriter{w, http.StatusOK} -} - -func (trw *tracingResponseWriter) WriteHeader(code int) { - trw.statusCode = code - trw.ResponseWriter.WriteHeader(code) -} - -func tracer(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - clientContext, _ := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header)) - span, traceCtx := opentracing.StartSpanFromContext(r.Context(), "http.handler", - ext.RPCServerOption(clientContext), - opentracer.SpanType(ddtrace_ext.AppTypeWeb), - ) - defer span.Finish() - - ext.HTTPMethod.Set(span, r.Method) - ext.HTTPUrl.Set(span, r.URL.Path) - resourceName := r.URL.Path - resourceName = r.Method + " " + resourceName - span.SetTag("resource.name", resourceName) - - if reqID := getRequestID(r.Context()); reqID != "" { - span.SetTag("http.request_id", reqID) - } - - trw := newTracingResponseWriter(w) - next.ServeHTTP(trw, r.WithContext(traceCtx)) - - status := trw.statusCode - - // Setting the status as an int doesn't propogate for use in datadog dashboards, - // so we convert to a string. - span.SetTag(string(ext.HTTPStatusCode), strconv.Itoa(status)) - - if status >= 500 && status < 600 { - ext.Error.Set(span, true) - span.SetTag("error.type", fmt.Sprintf("%d: %s", status, http.StatusText(status))) - span.LogKV( - "event", "error", - "message", fmt.Sprintf("%d: %s", status, http.StatusText(status)), - ) - } - }) -} diff --git a/api/tracer_test.go b/api/tracer_test.go deleted file mode 100644 index 4dcb51193..000000000 --- a/api/tracer_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package api - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/mocktracer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type TracerTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - - instanceID uuid.UUID -} - -func TestTracer(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &TracerTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *TracerTestSuite) TestTracer_Spans() { - mt := mocktracer.New() - opentracing.SetGlobalTracer(mt) - - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "http://localhost/something1", nil) - ts.API.handler.ServeHTTP(w, req) - req = httptest.NewRequest(http.MethodGet, "http://localhost/something2", nil) - ts.API.handler.ServeHTTP(w, req) - - spans := mt.FinishedSpans() - if assert.Equal(ts.T(), 2, len(spans)) { - assert.Equal(ts.T(), "POST", spans[0].Tag("http.method")) - assert.Equal(ts.T(), "/something1", spans[0].Tag("http.url")) - assert.Equal(ts.T(), "POST /something1", spans[0].Tag("resource.name")) - assert.Equal(ts.T(), "404", spans[0].Tag("http.status_code")) - assert.NotEmpty(ts.T(), spans[0].Tag("http.request_id")) - - assert.Equal(ts.T(), "GET", spans[1].Tag("http.method")) - assert.Equal(ts.T(), "/something2", spans[1].Tag("http.url")) - assert.Equal(ts.T(), "GET /something2", spans[1].Tag("resource.name")) - assert.Equal(ts.T(), "404", spans[1].Tag("http.status_code")) - assert.NotEmpty(ts.T(), spans[1].Tag("http.request_id")) - } -} diff --git a/api/user.go b/api/user.go deleted file mode 100644 index 4ef51525a..000000000 --- a/api/user.go +++ /dev/null @@ -1,144 +0,0 @@ -package api - -import ( - "encoding/json" - "net/http" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" -) - -// UserUpdateParams parameters for updating a user -type UserUpdateParams struct { - Email string `json:"email"` - Password *string `json:"password"` - Data map[string]interface{} `json:"data"` - AppData map[string]interface{} `json:"app_metadata,omitempty"` - Phone string `json:"phone"` -} - -// UserGet returns a user -func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - claims := getClaims(ctx) - if claims == nil { - return badRequestError("Could not read claims") - } - - userID, err := uuid.FromString(claims.Subject) - if err != nil { - return badRequestError("Could not read User ID claim") - } - - aud := a.requestAud(ctx, r) - if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") - } - - user, err := models.FindUserByID(a.db, userID) - if err != nil { - if models.IsNotFoundError(err) { - return notFoundError(err.Error()) - } - return internalServerError("Database error finding user").WithInternalError(err) - } - - return sendJSON(w, http.StatusOK, user) -} - -// UserUpdate updates fields on a user -func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - instanceID := getInstanceID(ctx) - - params := &UserUpdateParams{} - jsonDecoder := json.NewDecoder(r.Body) - err := jsonDecoder.Decode(params) - if err != nil { - return badRequestError("Could not read User Update params: %v", err) - } - - claims := getClaims(ctx) - userID, err := uuid.FromString(claims.Subject) - if err != nil { - return badRequestError("Could not read User ID claim") - } - - user, err := models.FindUserByID(a.db, userID) - if err != nil { - if models.IsNotFoundError(err) { - return notFoundError(err.Error()) - } - return internalServerError("Database error finding user").WithInternalError(err) - } - - log := getLogEntry(r) - log.Debugf("Checking params for token %v", params) - - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - if params.Password != nil { - if len(*params.Password) < config.PasswordMinLength { - return invalidPasswordLengthError(config) - } - - if terr = user.UpdatePassword(tx, *params.Password); terr != nil { - return internalServerError("Error during password storage").WithInternalError(terr) - } - } - - if params.Data != nil { - if terr = user.UpdateUserMetaData(tx, params.Data); terr != nil { - return internalServerError("Error updating user").WithInternalError(terr) - } - } - - if params.AppData != nil { - if !a.isAdmin(ctx, user, config.JWT.Aud) { - return unauthorizedError("Updating app_metadata requires admin privileges") - } - - if terr = user.UpdateAppMetaData(tx, params.AppData); terr != nil { - return internalServerError("Error updating user").WithInternalError(terr) - } - } - - if params.Email != "" && params.Email != user.GetEmail() { - if terr = a.validateEmail(ctx, params.Email); terr != nil { - return terr - } - - var exists bool - if exists, terr = models.IsDuplicatedEmail(tx, instanceID, params.Email, user.Aud); terr != nil { - return internalServerError("Database error checking email").WithInternalError(terr) - } else if exists { - return unprocessableEntityError(DuplicateEmailMsg) - } - - mailer := a.Mailer(ctx) - referrer := a.getReferrer(r) - if config.Mailer.SecureEmailChangeEnabled { - if terr = a.sendSecureEmailChange(tx, user, mailer, params.Email, referrer); terr != nil { - return internalServerError("Error sending change email").WithInternalError(terr) - } - } else { - if terr = a.sendEmailChange(tx, user, mailer, params.Email, referrer); terr != nil { - return internalServerError("Error sending change email").WithInternalError(terr) - } - } - } - - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserModifiedAction, nil); terr != nil { - return internalServerError("Error recording audit log entry").WithInternalError(terr) - } - - return nil - }) - if err != nil { - return err - } - - return sendJSON(w, http.StatusOK, user) -} diff --git a/api/user_test.go b/api/user_test.go deleted file mode 100644 index baae1f35b..000000000 --- a/api/user_test.go +++ /dev/null @@ -1,103 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type UserTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - - instanceID uuid.UUID -} - -func TestUser(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &UserTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *UserTestSuite) SetupTest() { - models.TruncateAll(ts.API.db) - - // Create user - u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil) - require.NoError(ts.T(), err, "Error creating test user model") - require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") -} - -func (ts *UserTestSuite) TestUser_UpdatePassword() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - var cases = []struct { - desc string - update map[string]interface{} - expectedCode int - isAuthenticated bool - }{ - { - "Valid password length", - map[string]interface{}{ - "password": "newpass", - }, - http.StatusOK, - true, - }, - { - "Invalid password length", - map[string]interface{}{ - "password": "", - }, - http.StatusUnprocessableEntity, - false, - }, - } - - for _, c := range cases { - ts.Run(c.desc, func() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.update)) - - req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) - req.Header.Set("Content-Type", "application/json") - - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) - require.NoError(ts.T(), err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), w.Code, c.expectedCode) - - // Request body - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - passwordUpdate, _ := c.update["password"].(string) - require.Equal(ts.T(), c.isAuthenticated, u.Authenticate(passwordUpdate)) - }) - } -} diff --git a/api/verify.go b/api/verify.go deleted file mode 100644 index 8d5172835..000000000 --- a/api/verify.go +++ /dev/null @@ -1,392 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "net/url" - "strconv" - "time" - - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" - "github.com/sethvargo/go-password/password" -) - -var ( - // indicates that a user should be redirected due to an error - redirectWithQueryError = errors.New("redirect user") -) - -const ( - signupVerification = "signup" - recoveryVerification = "recovery" - inviteVerification = "invite" - magicLinkVerification = "magiclink" - emailChangeVerification = "email_change" - smsVerification = "sms" -) - -const ( - zeroConfirmation int = iota - singleConfirmation -) - -// VerifyParams are the parameters the Verify endpoint accepts -type VerifyParams struct { - Type string `json:"type"` - Token string `json:"token"` - Password string `json:"password"` - Phone string `json:"phone"` - RedirectTo string `json:"redirect_to"` -} - -// Verify exchanges a confirmation or recovery token to a refresh token -func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - config := a.getConfig(ctx) - - params := &VerifyParams{} - - switch r.Method { - // GET only supports signup type - case "GET": - params.Token = r.FormValue("token") - params.Password = "" - params.Type = r.FormValue("type") - params.RedirectTo = a.getRedirectURLOrReferrer(r, r.FormValue("redirect_to")) - case "POST": - jsonDecoder := json.NewDecoder(r.Body) - if err := jsonDecoder.Decode(params); err != nil { - return badRequestError("Could not read verification params: %v", err) - } - params.RedirectTo = a.getRedirectURLOrReferrer(r, params.RedirectTo) - default: - unprocessableEntityError("Sorry, only GET and POST methods are supported.") - } - - if params.Token == "" { - return unprocessableEntityError("Verify requires a token") - } - - var ( - user *models.User - err error - token *AccessTokenResponse - ) - - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - switch params.Type { - case signupVerification, inviteVerification: - user, terr = a.signupVerify(ctx, tx, params) - case recoveryVerification, magicLinkVerification: - user, terr = a.recoverVerify(ctx, tx, params) - case emailChangeVerification: - user, terr = a.emailChangeVerify(ctx, tx, params) - if user == nil && terr == nil { - // when double confirmation is required - rurl := a.prepRedirectURL("Confirmation link accepted. Please proceed to confirm link sent to the other email", params.RedirectTo) - http.Redirect(w, r, rurl, http.StatusSeeOther) - return nil - } - case smsVerification: - if params.Phone == "" { - return unprocessableEntityError("Sms Verification requires a phone number") - } - params.Phone = a.formatPhoneNumber(params.Phone) - if isValid := a.validateE164Format(params.Phone); !isValid { - return unprocessableEntityError("Invalid phone number format") - } - aud := a.requestAud(ctx, r) - user, terr = a.smsVerify(ctx, tx, params, aud) - default: - return unprocessableEntityError("Verify requires a verification type") - } - - if terr != nil { - var e *HTTPError - if errors.As(terr, &e) { - if errors.Is(e.InternalError, redirectWithQueryError) { - rurl := a.prepErrorRedirectURL(e, r, params.RedirectTo) - http.Redirect(w, r, rurl, http.StatusSeeOther) - return nil - } - } - return terr - } - - token, terr = a.issueRefreshToken(ctx, tx, user) - if terr != nil { - return terr - } - - if terr = a.setCookieTokens(config, token, false, w); terr != nil { - return internalServerError("Failed to set JWT cookie. %s", terr) - } - return nil - }) - if err != nil { - return err - } - - // GET requests should return to the app site after confirmation - switch r.Method { - case "GET": - rurl := params.RedirectTo - if token != nil { - q := url.Values{} - q.Set("access_token", token.Token) - q.Set("token_type", token.TokenType) - q.Set("expires_in", strconv.Itoa(token.ExpiresIn)) - q.Set("refresh_token", token.RefreshToken) - q.Set("type", params.Type) - rurl += "#" + q.Encode() - } - http.Redirect(w, r, rurl, http.StatusSeeOther) - case "POST": - return sendJSON(w, http.StatusOK, token) - } - - return nil -} - -func (a *API) signupVerify(ctx context.Context, conn *storage.Connection, params *VerifyParams) (*models.User, error) { - instanceID := getInstanceID(ctx) - config := a.getConfig(ctx) - - user, err := models.FindUserByConfirmationToken(conn, params.Token) - if err != nil { - if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()).WithInternalError(redirectWithQueryError) - } - return nil, internalServerError("Database error finding user").WithInternalError(err) - } - - if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalError(redirectWithQueryError) - } - - nextDay := user.ConfirmationSentAt.Add(24 * time.Hour) - if user.ConfirmationSentAt != nil && time.Now().After(nextDay) { - return nil, expiredTokenError("Confirmation token expired").WithInternalError(redirectWithQueryError) - } - - err = conn.Transaction(func(tx *storage.Connection) error { - var terr error - if user.EncryptedPassword == "" { - if user.InvitedAt != nil { - // sign them up with temporary password, and require application - // to present the user with a password set form - password, err := password.Generate(64, 10, 0, false, true) - if err != nil { - internalServerError("error creating user").WithInternalError(err) - } - if terr = user.UpdatePassword(tx, password); terr != nil { - return internalServerError("Error storing password").WithInternalError(terr) - } - } - } - - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserSignedUpAction, nil); terr != nil { - return terr - } - - if terr = triggerEventHooks(ctx, tx, SignupEvent, user, instanceID, config); terr != nil { - return terr - } - - if terr = user.Confirm(tx); terr != nil { - return internalServerError("Error confirming user").WithInternalError(terr) - } - return nil - }) - if err != nil { - return nil, err - } - return user, nil -} - -func (a *API) recoverVerify(ctx context.Context, conn *storage.Connection, params *VerifyParams) (*models.User, error) { - instanceID := getInstanceID(ctx) - config := a.getConfig(ctx) - user, err := models.FindUserByRecoveryToken(conn, params.Token) - if err != nil { - if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()).WithInternalError(redirectWithQueryError) - } - return nil, internalServerError("Database error finding user").WithInternalError(err) - } - - if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalError(redirectWithQueryError) - } - - nextDay := user.RecoverySentAt.Add(24 * time.Hour) - if user.RecoverySentAt != nil && time.Now().After(nextDay) { - return nil, expiredTokenError("Recovery token expired").WithInternalError(redirectWithQueryError) - } - - err = conn.Transaction(func(tx *storage.Connection) error { - var terr error - if terr = user.Recover(tx); terr != nil { - return terr - } - if !user.IsConfirmed() { - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserSignedUpAction, nil); terr != nil { - return terr - } - - if terr = triggerEventHooks(ctx, tx, SignupEvent, user, instanceID, config); terr != nil { - return terr - } - if terr = user.Confirm(tx); terr != nil { - return terr - } - } - return nil - }) - - if err != nil { - return nil, internalServerError("Database error updating user").WithInternalError(err) - } - return user, nil -} - -func (a *API) smsVerify(ctx context.Context, conn *storage.Connection, params *VerifyParams, aud string) (*models.User, error) { - instanceID := getInstanceID(ctx) - config := a.getConfig(ctx) - user, err := models.FindUserByPhoneAndAudience(conn, instanceID, params.Phone, aud) - if err != nil { - if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()).WithInternalError(redirectWithQueryError) - } - return nil, internalServerError("Database error finding user").WithInternalError(err) - } - - if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalError(redirectWithQueryError) - } - - now := time.Now() - expiresAt := user.ConfirmationSentAt.Add(time.Second * time.Duration(config.Sms.OtpExp)) - - // check if token has expired or is invalid - if isOtpValid := now.Before(expiresAt) && params.Token == user.ConfirmationToken; !isOtpValid { - return nil, expiredTokenError("Otp has expired or is invalid").WithInternalError(redirectWithQueryError) - } - - err = conn.Transaction(func(tx *storage.Connection) error { - var terr error - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserSignedUpAction, nil); terr != nil { - return terr - } - - if terr = triggerEventHooks(ctx, tx, SignupEvent, user, instanceID, config); terr != nil { - return terr - } - - if terr = user.ConfirmPhone(tx); terr != nil { - return internalServerError("Error confirming user").WithInternalError(terr) - } - return nil - }) - if err != nil { - return nil, err - } - return user, nil -} - -func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request, rurl string) string { - q := url.Values{} - - log := getLogEntry(r) - log.Error(err.Message) - - if str, ok := oauthErrorMap[err.Code]; ok { - q.Set("error", str) - } - q.Set("error_code", strconv.Itoa(err.Code)) - q.Set("error_description", err.Message) - return rurl + "#" + q.Encode() -} - -func (a *API) prepRedirectURL(message string, rurl string) string { - q := url.Values{} - q.Set("message", message) - return rurl + "#" + q.Encode() -} - -func (a *API) emailChangeVerify(ctx context.Context, conn *storage.Connection, params *VerifyParams) (*models.User, error) { - instanceID := getInstanceID(ctx) - config := a.getConfig(ctx) - user, err := models.FindUserByEmailChangeToken(conn, params.Token) - if err != nil { - if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()).WithInternalError(redirectWithQueryError) - } - return nil, internalServerError("Database error finding user").WithInternalError(err) - } - - if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalError(redirectWithQueryError) - } - - nextDay := user.EmailChangeSentAt.Add(24 * time.Hour) - if user.EmailChangeSentAt != nil && time.Now().After(nextDay) { - err = a.db.Transaction(func(tx *storage.Connection) error { - user.EmailChangeConfirmStatus = zeroConfirmation - return tx.UpdateOnly(user, "email_change_confirm_status") - }) - if err != nil { - return nil, err - } - return nil, expiredTokenError("Email change token expired").WithInternalError(redirectWithQueryError) - } - - if config.Mailer.SecureEmailChangeEnabled { - if user.EmailChangeConfirmStatus == zeroConfirmation { - err = a.db.Transaction(func(tx *storage.Connection) error { - user.EmailChangeConfirmStatus = singleConfirmation - if params.Token == user.EmailChangeTokenCurrent { - user.EmailChangeTokenCurrent = "" - } else if params.Token == user.EmailChangeTokenNew { - user.EmailChangeTokenNew = "" - } - if terr := tx.UpdateOnly(user, "email_change_confirm_status", "email_change_token_current", "email_change_token_new"); terr != nil { - return terr - } - return nil - }) - if err != nil { - return nil, err - } - return nil, nil - } - } - - // one email is confirmed at this point - err = a.db.Transaction(func(tx *storage.Connection) error { - var terr error - - if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserModifiedAction, nil); terr != nil { - return terr - } - - if terr = triggerEventHooks(ctx, tx, EmailChangeEvent, user, instanceID, config); terr != nil { - return terr - } - - if terr = user.ConfirmEmailChange(tx, zeroConfirmation); terr != nil { - return internalServerError("Error confirm email").WithInternalError(terr) - } - - return nil - }) - if err != nil { - return nil, err - } - - return user, nil -} diff --git a/api/verify_test.go b/api/verify_test.go deleted file mode 100644 index 0e44da19b..000000000 --- a/api/verify_test.go +++ /dev/null @@ -1,459 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type VerifyTestSuite struct { - suite.Suite - API *API - Config *conf.Configuration - - instanceID uuid.UUID -} - -func TestVerify(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() - require.NoError(t, err) - - ts := &VerifyTestSuite{ - API: api, - Config: config, - instanceID: instanceID, - } - defer api.db.Close() - - suite.Run(t, ts) -} - -func (ts *VerifyTestSuite) SetupTest() { - models.TruncateAll(ts.API.db) - - // Create user - u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil) - u.Phone = "12345678" - require.NoError(ts.T(), err, "Error creating test user model") - require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") -} - -func (ts *VerifyTestSuite) TestVerify_PasswordRecovery() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - u.RecoverySentAt = &time.Time{} - require.NoError(ts.T(), ts.API.db.Update(u)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - })) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) - assert.False(ts.T(), u.IsConfirmed()) - - // Send Verify request - var vbuffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&vbuffer).Encode(map[string]interface{}{ - "type": "recovery", - "token": u.RecoveryToken, - })) - - req = httptest.NewRequest(http.MethodPost, "http://localhost/verify", &vbuffer) - req.Header.Set("Content-Type", "application/json") - - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.True(ts.T(), u.IsConfirmed()) -} - -func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - u.ConfirmationToken = "asdf3" - sentTime := time.Now().Add(-48 * time.Hour) - u.ConfirmationSentAt = &sentTime - require.NoError(ts.T(), ts.API.db.Update(u)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "type": signupVerification, - "token": u.ConfirmationToken, - })) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusSeeOther, w.Code) - - url, err := w.Result().Location() - require.NoError(ts.T(), err) - assert.Equal(ts.T(), "error_code=410&error_description=Confirmation+token+expired", url.Fragment) -} - -func (ts *VerifyTestSuite) TestInvalidSmsOtp() { - u, err := models.FindUserByPhoneAndAudience(ts.API.db, ts.instanceID, "12345678", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - u.ConfirmationToken = "123456" - sentTime := time.Now().Add(-48 * time.Hour) - u.ConfirmationSentAt = &sentTime - require.NoError(ts.T(), ts.API.db.Update(u)) - - type expected struct { - code int - fragments string - } - - expectedResponse := expected{ - code: http.StatusSeeOther, - fragments: "error_code=410&error_description=Otp+has+expired+or+is+invalid", - } - - cases := []struct { - desc string - sentTime time.Time - body map[string]interface{} - expected - }{ - { - desc: "Expired OTP", - sentTime: time.Now().Add(-48 * time.Hour), - body: map[string]interface{}{ - "type": smsVerification, - "token": u.ConfirmationToken, - "phone": u.GetPhone(), - }, - expected: expectedResponse, - }, - { - desc: "Incorrect OTP", - sentTime: time.Now(), - body: map[string]interface{}{ - "type": smsVerification, - "token": "incorrect_otp", - "phone": u.GetPhone(), - }, - expected: expectedResponse, - }, - } - - for _, c := range cases { - ts.Run(c.desc, func() { - // update token sent time - sentTime = time.Now() - u.ConfirmationSentAt = &c.sentTime - require.NoError(ts.T(), ts.API.db.Update(u)) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), c.expected.code, w.Code) - - url, err := w.Result().Location() - require.NoError(ts.T(), err) - assert.Equal(ts.T(), c.expected.fragments, url.Fragment) - }) - } -} - -func (ts *VerifyTestSuite) TestExpiredRecoveryToken() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - u.RecoveryToken = "asdf3" - sentTime := time.Now().Add(-48 * time.Hour) - u.RecoverySentAt = &sentTime - require.NoError(ts.T(), ts.API.db.Update(u)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "type": recoveryVerification, - "token": u.RecoveryToken, - })) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - - ts.API.handler.ServeHTTP(w, req) - - assert.Equal(ts.T(), http.StatusSeeOther, w.Code, w.Body.String()) -} - -func (ts *VerifyTestSuite) TestVerifyPermitedCustomUri() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - u.RecoverySentAt = &time.Time{} - require.NoError(ts.T(), ts.API.db.Update(u)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - })) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) - assert.False(ts.T(), u.IsConfirmed()) - - redirectURL, _ := url.Parse(ts.Config.URIAllowList[0]) - - reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "recovery", u.RecoveryToken, redirectURL.String()) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) - - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusSeeOther, w.Code) - rURL, _ := w.Result().Location() - assert.Equal(ts.T(), redirectURL.Hostname(), rURL.Hostname()) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.True(ts.T(), u.IsConfirmed()) -} - -func (ts *VerifyTestSuite) TestVerifyNotPermitedCustomUri() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - u.RecoverySentAt = &time.Time{} - require.NoError(ts.T(), ts.API.db.Update(u)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ - "email": "test@example.com", - })) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) - assert.False(ts.T(), u.IsConfirmed()) - - fakeredirectURL, _ := url.Parse("http://custom-url.com") - siteURL, _ := url.Parse(ts.Config.SiteURL) - - reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "recovery", u.RecoveryToken, fakeredirectURL.String()) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) - - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusSeeOther, w.Code) - rURL, _ := w.Result().Location() - assert.Equal(ts.T(), siteURL.Hostname(), rURL.Hostname()) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.True(ts.T(), u.IsConfirmed()) -} - -func (ts *VerifyTestSuite) TestVerifySignupWithredirectURLContainedPath() { - testCases := []struct { - desc string - siteURL string - uriAllowList []string - requestredirectURL string - expectedredirectURL string - }{ - { - desc: "same site url and redirect url with path", - siteURL: "http://localhost:3000/#/", - uriAllowList: []string{"http://localhost:3000"}, - requestredirectURL: "http://localhost:3000/#/", - expectedredirectURL: "http://localhost:3000/#/", - }, - { - desc: "different site url and redirect url in allow list", - siteURL: "https://someapp-something.codemagic.app/#/", - uriAllowList: []string{"http://localhost:3000"}, - requestredirectURL: "http://localhost:3000", - expectedredirectURL: "http://localhost:3000", - }, - { - desc: "different site url and redirect url not in allow list", - siteURL: "https://someapp-something.codemagic.app/#/", - uriAllowList: []string{"http://localhost:3000"}, - requestredirectURL: "http://localhost:3000/docs", - expectedredirectURL: "https://someapp-something.codemagic.app/#/", - }, - } - - for _, tC := range testCases { - ts.Run(tC.desc, func() { - // prepare test data - ts.Config.SiteURL = tC.siteURL - redirectURL := tC.requestredirectURL - ts.Config.URIAllowList = tC.uriAllowList - - // set verify token to user as it actual do in magic link method - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - u.ConfirmationToken = "someToken" - sendTime := time.Now().Add(time.Hour) - u.ConfirmationSentAt = &sendTime - require.NoError(ts.T(), ts.API.db.Update(u)) - - reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "signup", u.ConfirmationToken, redirectURL) - req := httptest.NewRequest(http.MethodGet, reqURL, nil) - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusSeeOther, w.Code) - rURL, _ := w.Result().Location() - assert.Contains(ts.T(), rURL.String(), tC.expectedredirectURL) // redirected url starts with per test value - - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.True(ts.T(), u.IsConfirmed()) - }) - } -} - -func (ts *VerifyTestSuite) TestVerifyBannedUser() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - u.ConfirmationToken = "confirmation_token" - u.RecoveryToken = "recovery_token" - u.EmailChangeTokenCurrent = "current_email_change_token" - u.EmailChangeTokenNew = "new_email_change_token" - t := time.Now() - u.ConfirmationSentAt = &t - u.RecoverySentAt = &t - u.EmailChangeSentAt = &t - - t = time.Now().Add(24 * time.Hour) - u.BannedUntil = &t - require.NoError(ts.T(), ts.API.db.Update(u)) - - cases := []struct { - desc string - payload *VerifyParams - }{ - { - "Verify banned user on signup", - &VerifyParams{ - Type: "signup", - Token: u.ConfirmationToken, - }, - }, - { - "Verify banned user on invite", - &VerifyParams{ - Type: "invite", - Token: u.ConfirmationToken, - }, - }, - { - "Verify banned phone user on sms", - &VerifyParams{ - Type: "sms", - Token: u.ConfirmationToken, - Phone: u.GetPhone(), - }, - }, - { - "Verify banned user on recover", - &VerifyParams{ - Type: "recovery", - Token: u.RecoveryToken, - }, - }, - { - "Verify banned user on magiclink", - &VerifyParams{ - Type: "magiclink", - Token: u.RecoveryToken, - }, - }, - { - "Verify banned user on email change", - &VerifyParams{ - Type: "email_change", - Token: u.EmailChangeTokenCurrent, - }, - }, - } - - for _, c := range cases { - ts.Run(c.desc, func() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload)) - - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusSeeOther, w.Code) - - url, err := w.Result().Location() - require.NoError(ts.T(), err) - assert.Equal(ts.T(), "error=unauthorized_client&error_code=401&error_description=Error+confirming+user", url.Fragment) - }) - } -} diff --git a/app.json b/app.json index d82737b7d..486865630 100644 --- a/app.json +++ b/app.json @@ -2,11 +2,11 @@ "name": "Gotrue", "description": "", "website": "https://www.gotrueapi.org", - "repository": "https://github.com/netlify/gotrue", + "repository": "https://github.com/supabase/gotrue", "env": { "DATABASE_URL": {}, "GOTRUE_DB_DRIVER": { - "value": "mysql" + "value": "postgres" }, "GOTRUE_DB_AUTOMIGRATE": { "value": true diff --git a/client/admin/client.go b/client/admin/client.go new file mode 100644 index 000000000..cd5055027 --- /dev/null +++ b/client/admin/client.go @@ -0,0 +1,2727 @@ +// Package admin provides primitives to interact with the openapi HTTP API. +// +// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.4.2-0.20250102212541-8bbe226927c9 DO NOT EDIT. +package admin + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/oapi-codegen/runtime" + openapi_types "github.com/oapi-codegen/runtime/types" +) + +const ( + APIKeyAuthScopes = "APIKeyAuth.Scopes" + AdminAuthScopes = "AdminAuth.Scopes" +) + +// Defines values for ErrorSchemaWeakPasswordReasons. +const ( + Characters ErrorSchemaWeakPasswordReasons = "characters" + Length ErrorSchemaWeakPasswordReasons = "length" + Pwned ErrorSchemaWeakPasswordReasons = "pwned" +) + +// Defines values for PostAdminGenerateLinkJSONBodyType. +const ( + EmailChangeCurrent PostAdminGenerateLinkJSONBodyType = "email_change_current" + EmailChangeNew PostAdminGenerateLinkJSONBodyType = "email_change_new" + Magiclink PostAdminGenerateLinkJSONBodyType = "magiclink" + Recovery PostAdminGenerateLinkJSONBodyType = "recovery" + Signup PostAdminGenerateLinkJSONBodyType = "signup" +) + +// Defines values for PostAdminSsoProvidersJSONBodyType. +const ( + Saml PostAdminSsoProvidersJSONBodyType = "saml" +) + +// ErrorSchema defines model for ErrorSchema. +type ErrorSchema struct { + // Code The HTTP status code. Usually missing if `error` is present. + Code *int `json:"code,omitempty"` + + // Error Certain responses will contain this property with the provided values. + // + // Usually one of these: + // - invalid_request + // - unauthorized_client + // - access_denied + // - server_error + // - temporarily_unavailable + // - unsupported_otp_type + Error *string `json:"error,omitempty"` + + // ErrorCode A short code used to describe the class of error encountered. + ErrorCode *string `json:"error_code,omitempty"` + + // ErrorDescription Certain responses that have an `error` property may have this property which describes the error. + ErrorDescription *string `json:"error_description,omitempty"` + + // Msg A basic message describing the problem with the request. Usually missing if `error` is present. + Msg *string `json:"msg,omitempty"` + + // WeakPassword Only returned on the `/signup` endpoint if the password used is too weak. Inspect the `reasons` and `msg` property to identify the causes. + WeakPassword *struct { + Reasons *[]ErrorSchemaWeakPasswordReasons `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` +} + +// ErrorSchemaWeakPasswordReasons defines model for ErrorSchema.WeakPassword.Reasons. +type ErrorSchemaWeakPasswordReasons string + +// IdentitySchema defines model for IdentitySchema. +type IdentitySchema struct { + CreatedAt *time.Time `json:"created_at,omitempty"` + Email *openapi_types.Email `json:"email,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + IdentityData *map[string]interface{} `json:"identity_data,omitempty"` + IdentityId *openapi_types.UUID `json:"identity_id,omitempty"` + LastSignInAt *time.Time `json:"last_sign_in_at,omitempty"` + Provider *string `json:"provider,omitempty"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + UserId *openapi_types.UUID `json:"user_id,omitempty"` +} + +// MFAFactorSchema Represents a MFA factor. +type MFAFactorSchema struct { + CreatedAt *time.Time `json:"created_at,omitempty"` + + // FactorType Usually one of: + // - totp + // - phone + // - webauthn + FactorType *string `json:"factor_type,omitempty"` + FriendlyName *string `json:"friendly_name,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + LastChallengedAt *time.Time `json:"last_challenged_at"` + Phone *string `json:"phone"` + + // Status Usually one of: + // - verified + // - unverified + Status *string `json:"status,omitempty"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + WebAuthnCredential *string `json:"web_authn_credential,omitempty"` +} + +// SAMLAttributeMappingSchema defines model for SAMLAttributeMappingSchema. +type SAMLAttributeMappingSchema struct { + Keys *map[string]interface{} `json:"keys,omitempty"` +} + +// SSOProviderSchema defines model for SSOProviderSchema. +type SSOProviderSchema struct { + Id *openapi_types.UUID `json:"id,omitempty"` + Saml *struct { + AttributeMapping *SAMLAttributeMappingSchema `json:"attribute_mapping,omitempty"` + EntityId *string `json:"entity_id,omitempty"` + MetadataUrl *string `json:"metadata_url,omitempty"` + MetadataXml *string `json:"metadata_xml,omitempty"` + } `json:"saml,omitempty"` + SsoDomains *[]struct { + Domain *string `json:"domain,omitempty"` + } `json:"sso_domains,omitempty"` +} + +// UserSchema Object describing the user related to the issued access and refresh tokens. +type UserSchema struct { + AppMetadata *map[string]interface{} `json:"app_metadata,omitempty"` + // Deprecated: + Aud *string `json:"aud,omitempty"` + BannedUntil *time.Time `json:"banned_until,omitempty"` + ConfirmationSentAt *time.Time `json:"confirmation_sent_at,omitempty"` + ConfirmedAt *time.Time `json:"confirmed_at,omitempty"` + CreatedAt *time.Time `json:"created_at,omitempty"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + + // Email User's primary contact email. In most cases you can uniquely identify a user by their email address, but not in all cases. + Email *string `json:"email,omitempty"` + EmailChangeSentAt *time.Time `json:"email_change_sent_at,omitempty"` + EmailConfirmedAt *time.Time `json:"email_confirmed_at,omitempty"` + Factors *[]MFAFactorSchema `json:"factors,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + Identities *[]IdentitySchema `json:"identities,omitempty"` + IsAnonymous *bool `json:"is_anonymous,omitempty"` + LastSignInAt *time.Time `json:"last_sign_in_at,omitempty"` + NewEmail *openapi_types.Email `json:"new_email,omitempty"` + NewPhone *string `json:"new_phone,omitempty"` + + // Phone User's primary contact phone number. In most cases you can uniquely identify a user by their phone number, but not in all cases. + Phone *string `json:"phone,omitempty"` + PhoneChangeSentAt *time.Time `json:"phone_change_sent_at,omitempty"` + PhoneConfirmedAt *time.Time `json:"phone_confirmed_at,omitempty"` + ReauthenticationSentAt *time.Time `json:"reauthentication_sent_at,omitempty"` + RecoverySentAt *time.Time `json:"recovery_sent_at,omitempty"` + Role *string `json:"role,omitempty"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + UserMetadata *map[string]interface{} `json:"user_metadata,omitempty"` +} + +// BadRequestResponse defines model for BadRequestResponse. +type BadRequestResponse = ErrorSchema + +// ForbiddenResponse defines model for ForbiddenResponse. +type ForbiddenResponse = ErrorSchema + +// UnauthorizedResponse defines model for UnauthorizedResponse. +type UnauthorizedResponse = ErrorSchema + +// GetAdminAuditParams defines parameters for GetAdminAudit. +type GetAdminAuditParams struct { + Page *int `form:"page,omitempty" json:"page,omitempty"` + PerPage *int `form:"per_page,omitempty" json:"per_page,omitempty"` +} + +// PostAdminGenerateLinkJSONBody defines parameters for PostAdminGenerateLink. +type PostAdminGenerateLinkJSONBody struct { + Data *map[string]interface{} `json:"data,omitempty"` + Email openapi_types.Email `json:"email"` + NewEmail *openapi_types.Email `json:"new_email,omitempty"` + Password *string `json:"password,omitempty"` + RedirectTo *string `json:"redirect_to,omitempty"` + Type PostAdminGenerateLinkJSONBodyType `json:"type"` +} + +// PostAdminGenerateLinkJSONBodyType defines parameters for PostAdminGenerateLink. +type PostAdminGenerateLinkJSONBodyType string + +// PostAdminSsoProvidersJSONBody defines parameters for PostAdminSsoProviders. +type PostAdminSsoProvidersJSONBody struct { + AttributeMapping *SAMLAttributeMappingSchema `json:"attribute_mapping,omitempty"` + Domains *[]string `json:"domains,omitempty"` + MetadataUrl *string `json:"metadata_url,omitempty"` + MetadataXml *string `json:"metadata_xml,omitempty"` + Type PostAdminSsoProvidersJSONBodyType `json:"type"` +} + +// PostAdminSsoProvidersJSONBodyType defines parameters for PostAdminSsoProviders. +type PostAdminSsoProvidersJSONBodyType string + +// PutAdminSsoProvidersSsoProviderIdJSONBody defines parameters for PutAdminSsoProvidersSsoProviderId. +type PutAdminSsoProvidersSsoProviderIdJSONBody struct { + AttributeMapping *SAMLAttributeMappingSchema `json:"attribute_mapping,omitempty"` + Domains *[]string `json:"domains,omitempty"` + MetadataUrl *string `json:"metadata_url,omitempty"` + MetadataXml *string `json:"metadata_xml,omitempty"` +} + +// GetAdminUsersParams defines parameters for GetAdminUsers. +type GetAdminUsersParams struct { + Page *int `form:"page,omitempty" json:"page,omitempty"` + PerPage *int `form:"per_page,omitempty" json:"per_page,omitempty"` +} + +// PutAdminUsersUserIdFactorsFactorIdJSONBody defines parameters for PutAdminUsersUserIdFactorsFactorId. +type PutAdminUsersUserIdFactorsFactorIdJSONBody = map[string]interface{} + +// PostInviteJSONBody defines parameters for PostInvite. +type PostInviteJSONBody struct { + Data *map[string]interface{} `json:"data,omitempty"` + Email string `json:"email"` +} + +// PostAdminGenerateLinkJSONRequestBody defines body for PostAdminGenerateLink for application/json ContentType. +type PostAdminGenerateLinkJSONRequestBody PostAdminGenerateLinkJSONBody + +// PostAdminSsoProvidersJSONRequestBody defines body for PostAdminSsoProviders for application/json ContentType. +type PostAdminSsoProvidersJSONRequestBody PostAdminSsoProvidersJSONBody + +// PutAdminSsoProvidersSsoProviderIdJSONRequestBody defines body for PutAdminSsoProvidersSsoProviderId for application/json ContentType. +type PutAdminSsoProvidersSsoProviderIdJSONRequestBody PutAdminSsoProvidersSsoProviderIdJSONBody + +// PutAdminUsersUserIdJSONRequestBody defines body for PutAdminUsersUserId for application/json ContentType. +type PutAdminUsersUserIdJSONRequestBody = UserSchema + +// PutAdminUsersUserIdFactorsFactorIdJSONRequestBody defines body for PutAdminUsersUserIdFactorsFactorId for application/json ContentType. +type PutAdminUsersUserIdFactorsFactorIdJSONRequestBody = PutAdminUsersUserIdFactorsFactorIdJSONBody + +// PostInviteJSONRequestBody defines body for PostInvite for application/json ContentType. +type PostInviteJSONRequestBody PostInviteJSONBody + +// RequestEditorFn is the function signature for the RequestEditor callback function +type RequestEditorFn func(ctx context.Context, req *http.Request) error + +// Doer performs HTTP requests. +// +// The standard http.Client implements this interface. +type HttpRequestDoer interface { + Do(req *http.Request) (*http.Response, error) +} + +// Client which conforms to the OpenAPI3 specification for this service. +type Client struct { + // The endpoint of the server conforming to this interface, with scheme, + // https://api.deepmap.com for example. This can contain a path relative + // to the server, such as https://api.deepmap.com/dev-test, and all the + // paths in the swagger spec will be appended to the server. + Server string + + // Doer for performing requests, typically a *http.Client with any + // customized settings, such as certificate chains. + Client HttpRequestDoer + + // A list of callbacks for modifying requests which are generated before sending over + // the network. + RequestEditors []RequestEditorFn +} + +// ClientOption allows setting custom parameters during construction +type ClientOption func(*Client) error + +// Creates a new Client, with reasonable defaults +func NewClient(server string, opts ...ClientOption) (*Client, error) { + // create a client with sane default values + client := Client{ + Server: server, + } + // mutate client and add all optional params + for _, o := range opts { + if err := o(&client); err != nil { + return nil, err + } + } + // ensure the server URL always has a trailing slash + if !strings.HasSuffix(client.Server, "/") { + client.Server += "/" + } + // create httpClient, if not already present + if client.Client == nil { + client.Client = &http.Client{} + } + return &client, nil +} + +// WithHTTPClient allows overriding the default Doer, which is +// automatically created using http.Client. This is useful for tests. +func WithHTTPClient(doer HttpRequestDoer) ClientOption { + return func(c *Client) error { + c.Client = doer + return nil + } +} + +// WithRequestEditorFn allows setting up a callback function, which will be +// called right before sending the request. This can be used to mutate the request. +func WithRequestEditorFn(fn RequestEditorFn) ClientOption { + return func(c *Client) error { + c.RequestEditors = append(c.RequestEditors, fn) + return nil + } +} + +// The interface specification for the client above. +type ClientInterface interface { + // GetAdminAudit request + GetAdminAudit(ctx context.Context, params *GetAdminAuditParams, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PostAdminGenerateLinkWithBody request with any body + PostAdminGenerateLinkWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PostAdminGenerateLink(ctx context.Context, body PostAdminGenerateLinkJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminSsoProviders request + GetAdminSsoProviders(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PostAdminSsoProvidersWithBody request with any body + PostAdminSsoProvidersWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PostAdminSsoProviders(ctx context.Context, body PostAdminSsoProvidersJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // DeleteAdminSsoProvidersSsoProviderId request + DeleteAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminSsoProvidersSsoProviderId request + GetAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PutAdminSsoProvidersSsoProviderIdWithBody request with any body + PutAdminSsoProvidersSsoProviderIdWithBody(ctx context.Context, ssoProviderId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PutAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminUsers request + GetAdminUsers(ctx context.Context, params *GetAdminUsersParams, reqEditors ...RequestEditorFn) (*http.Response, error) + + // DeleteAdminUsersUserId request + DeleteAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminUsersUserId request + GetAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PutAdminUsersUserIdWithBody request with any body + PutAdminUsersUserIdWithBody(ctx context.Context, userId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PutAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminUsersUserIdFactors request + GetAdminUsersUserIdFactors(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // DeleteAdminUsersUserIdFactorsFactorId request + DeleteAdminUsersUserIdFactorsFactorId(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PutAdminUsersUserIdFactorsFactorIdWithBody request with any body + PutAdminUsersUserIdFactorsFactorIdWithBody(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PutAdminUsersUserIdFactorsFactorId(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PostInviteWithBody request with any body + PostInviteWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PostInvite(ctx context.Context, body PostInviteJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) +} + +func (c *Client) GetAdminAudit(ctx context.Context, params *GetAdminAuditParams, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminAuditRequest(c.Server, params) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostAdminGenerateLinkWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostAdminGenerateLinkRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostAdminGenerateLink(ctx context.Context, body PostAdminGenerateLinkJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostAdminGenerateLinkRequest(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminSsoProviders(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminSsoProvidersRequest(c.Server) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostAdminSsoProvidersWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostAdminSsoProvidersRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostAdminSsoProviders(ctx context.Context, body PostAdminSsoProvidersJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostAdminSsoProvidersRequest(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) DeleteAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewDeleteAdminSsoProvidersSsoProviderIdRequest(c.Server, ssoProviderId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminSsoProvidersSsoProviderIdRequest(c.Server, ssoProviderId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminSsoProvidersSsoProviderIdWithBody(ctx context.Context, ssoProviderId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminSsoProvidersSsoProviderIdRequestWithBody(c.Server, ssoProviderId, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminSsoProvidersSsoProviderIdRequest(c.Server, ssoProviderId, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminUsers(ctx context.Context, params *GetAdminUsersParams, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminUsersRequest(c.Server, params) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) DeleteAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewDeleteAdminUsersUserIdRequest(c.Server, userId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminUsersUserIdRequest(c.Server, userId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminUsersUserIdWithBody(ctx context.Context, userId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminUsersUserIdRequestWithBody(c.Server, userId, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminUsersUserIdRequest(c.Server, userId, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminUsersUserIdFactors(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminUsersUserIdFactorsRequest(c.Server, userId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) DeleteAdminUsersUserIdFactorsFactorId(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewDeleteAdminUsersUserIdFactorsFactorIdRequest(c.Server, userId, factorId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminUsersUserIdFactorsFactorIdWithBody(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminUsersUserIdFactorsFactorIdRequestWithBody(c.Server, userId, factorId, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminUsersUserIdFactorsFactorId(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminUsersUserIdFactorsFactorIdRequest(c.Server, userId, factorId, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostInviteWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostInviteRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostInvite(ctx context.Context, body PostInviteJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostInviteRequest(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +// NewGetAdminAuditRequest generates requests for GetAdminAudit +func NewGetAdminAuditRequest(server string, params *GetAdminAuditParams) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/audit") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + if params != nil { + queryValues := queryURL.Query() + + if params.Page != nil { + + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "page", runtime.ParamLocationQuery, *params.Page); err != nil { + return nil, err + } else if parsed, err := url.ParseQuery(queryFrag); err != nil { + return nil, err + } else { + for k, v := range parsed { + for _, v2 := range v { + queryValues.Add(k, v2) + } + } + } + + } + + if params.PerPage != nil { + + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "per_page", runtime.ParamLocationQuery, *params.PerPage); err != nil { + return nil, err + } else if parsed, err := url.ParseQuery(queryFrag); err != nil { + return nil, err + } else { + for k, v := range parsed { + for _, v2 := range v { + queryValues.Add(k, v2) + } + } + } + + } + + queryURL.RawQuery = queryValues.Encode() + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPostAdminGenerateLinkRequest calls the generic PostAdminGenerateLink builder with application/json body +func NewPostAdminGenerateLinkRequest(server string, body PostAdminGenerateLinkJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPostAdminGenerateLinkRequestWithBody(server, "application/json", bodyReader) +} + +// NewPostAdminGenerateLinkRequestWithBody generates requests for PostAdminGenerateLink with any type of body +func NewPostAdminGenerateLinkRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/generate_link") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewGetAdminSsoProvidersRequest generates requests for GetAdminSsoProviders +func NewGetAdminSsoProvidersRequest(server string) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPostAdminSsoProvidersRequest calls the generic PostAdminSsoProviders builder with application/json body +func NewPostAdminSsoProvidersRequest(server string, body PostAdminSsoProvidersJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPostAdminSsoProvidersRequestWithBody(server, "application/json", bodyReader) +} + +// NewPostAdminSsoProvidersRequestWithBody generates requests for PostAdminSsoProviders with any type of body +func NewPostAdminSsoProvidersRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewDeleteAdminSsoProvidersSsoProviderIdRequest generates requests for DeleteAdminSsoProvidersSsoProviderId +func NewDeleteAdminSsoProvidersSsoProviderIdRequest(server string, ssoProviderId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "ssoProviderId", runtime.ParamLocationPath, ssoProviderId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("DELETE", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewGetAdminSsoProvidersSsoProviderIdRequest generates requests for GetAdminSsoProvidersSsoProviderId +func NewGetAdminSsoProvidersSsoProviderIdRequest(server string, ssoProviderId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "ssoProviderId", runtime.ParamLocationPath, ssoProviderId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPutAdminSsoProvidersSsoProviderIdRequest calls the generic PutAdminSsoProvidersSsoProviderId builder with application/json body +func NewPutAdminSsoProvidersSsoProviderIdRequest(server string, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPutAdminSsoProvidersSsoProviderIdRequestWithBody(server, ssoProviderId, "application/json", bodyReader) +} + +// NewPutAdminSsoProvidersSsoProviderIdRequestWithBody generates requests for PutAdminSsoProvidersSsoProviderId with any type of body +func NewPutAdminSsoProvidersSsoProviderIdRequestWithBody(server string, ssoProviderId openapi_types.UUID, contentType string, body io.Reader) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "ssoProviderId", runtime.ParamLocationPath, ssoProviderId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewGetAdminUsersRequest generates requests for GetAdminUsers +func NewGetAdminUsersRequest(server string, params *GetAdminUsersParams) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + if params != nil { + queryValues := queryURL.Query() + + if params.Page != nil { + + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "page", runtime.ParamLocationQuery, *params.Page); err != nil { + return nil, err + } else if parsed, err := url.ParseQuery(queryFrag); err != nil { + return nil, err + } else { + for k, v := range parsed { + for _, v2 := range v { + queryValues.Add(k, v2) + } + } + } + + } + + if params.PerPage != nil { + + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "per_page", runtime.ParamLocationQuery, *params.PerPage); err != nil { + return nil, err + } else if parsed, err := url.ParseQuery(queryFrag); err != nil { + return nil, err + } else { + for k, v := range parsed { + for _, v2 := range v { + queryValues.Add(k, v2) + } + } + } + + } + + queryURL.RawQuery = queryValues.Encode() + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewDeleteAdminUsersUserIdRequest generates requests for DeleteAdminUsersUserId +func NewDeleteAdminUsersUserIdRequest(server string, userId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("DELETE", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewGetAdminUsersUserIdRequest generates requests for GetAdminUsersUserId +func NewGetAdminUsersUserIdRequest(server string, userId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPutAdminUsersUserIdRequest calls the generic PutAdminUsersUserId builder with application/json body +func NewPutAdminUsersUserIdRequest(server string, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPutAdminUsersUserIdRequestWithBody(server, userId, "application/json", bodyReader) +} + +// NewPutAdminUsersUserIdRequestWithBody generates requests for PutAdminUsersUserId with any type of body +func NewPutAdminUsersUserIdRequestWithBody(server string, userId openapi_types.UUID, contentType string, body io.Reader) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewGetAdminUsersUserIdFactorsRequest generates requests for GetAdminUsersUserIdFactors +func NewGetAdminUsersUserIdFactorsRequest(server string, userId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s/factors", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewDeleteAdminUsersUserIdFactorsFactorIdRequest generates requests for DeleteAdminUsersUserIdFactorsFactorId +func NewDeleteAdminUsersUserIdFactorsFactorIdRequest(server string, userId openapi_types.UUID, factorId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + var pathParam1 string + + pathParam1, err = runtime.StyleParamWithLocation("simple", false, "factorId", runtime.ParamLocationPath, factorId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s/factors/%s", pathParam0, pathParam1) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("DELETE", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPutAdminUsersUserIdFactorsFactorIdRequest calls the generic PutAdminUsersUserIdFactorsFactorId builder with application/json body +func NewPutAdminUsersUserIdFactorsFactorIdRequest(server string, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPutAdminUsersUserIdFactorsFactorIdRequestWithBody(server, userId, factorId, "application/json", bodyReader) +} + +// NewPutAdminUsersUserIdFactorsFactorIdRequestWithBody generates requests for PutAdminUsersUserIdFactorsFactorId with any type of body +func NewPutAdminUsersUserIdFactorsFactorIdRequestWithBody(server string, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + var pathParam1 string + + pathParam1, err = runtime.StyleParamWithLocation("simple", false, "factorId", runtime.ParamLocationPath, factorId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s/factors/%s", pathParam0, pathParam1) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewPostInviteRequest calls the generic PostInvite builder with application/json body +func NewPostInviteRequest(server string, body PostInviteJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPostInviteRequestWithBody(server, "application/json", bodyReader) +} + +// NewPostInviteRequestWithBody generates requests for PostInvite with any type of body +func NewPostInviteRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/invite") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +func (c *Client) applyEditors(ctx context.Context, req *http.Request, additionalEditors []RequestEditorFn) error { + for _, r := range c.RequestEditors { + if err := r(ctx, req); err != nil { + return err + } + } + for _, r := range additionalEditors { + if err := r(ctx, req); err != nil { + return err + } + } + return nil +} + +// ClientWithResponses builds on ClientInterface to offer response payloads +type ClientWithResponses struct { + ClientInterface +} + +// NewClientWithResponses creates a new ClientWithResponses, which wraps +// Client with return type handling +func NewClientWithResponses(server string, opts ...ClientOption) (*ClientWithResponses, error) { + client, err := NewClient(server, opts...) + if err != nil { + return nil, err + } + return &ClientWithResponses{client}, nil +} + +// WithBaseURL overrides the baseURL. +func WithBaseURL(baseURL string) ClientOption { + return func(c *Client) error { + newBaseURL, err := url.Parse(baseURL) + if err != nil { + return err + } + c.Server = newBaseURL.String() + return nil + } +} + +// ClientWithResponsesInterface is the interface specification for the client with responses above. +type ClientWithResponsesInterface interface { + // GetAdminAuditWithResponse request + GetAdminAuditWithResponse(ctx context.Context, params *GetAdminAuditParams, reqEditors ...RequestEditorFn) (*GetAdminAuditResponse, error) + + // PostAdminGenerateLinkWithBodyWithResponse request with any body + PostAdminGenerateLinkWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostAdminGenerateLinkResponse, error) + + PostAdminGenerateLinkWithResponse(ctx context.Context, body PostAdminGenerateLinkJSONRequestBody, reqEditors ...RequestEditorFn) (*PostAdminGenerateLinkResponse, error) + + // GetAdminSsoProvidersWithResponse request + GetAdminSsoProvidersWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*GetAdminSsoProvidersResponse, error) + + // PostAdminSsoProvidersWithBodyWithResponse request with any body + PostAdminSsoProvidersWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostAdminSsoProvidersResponse, error) + + PostAdminSsoProvidersWithResponse(ctx context.Context, body PostAdminSsoProvidersJSONRequestBody, reqEditors ...RequestEditorFn) (*PostAdminSsoProvidersResponse, error) + + // DeleteAdminSsoProvidersSsoProviderIdWithResponse request + DeleteAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminSsoProvidersSsoProviderIdResponse, error) + + // GetAdminSsoProvidersSsoProviderIdWithResponse request + GetAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminSsoProvidersSsoProviderIdResponse, error) + + // PutAdminSsoProvidersSsoProviderIdWithBodyWithResponse request with any body + PutAdminSsoProvidersSsoProviderIdWithBodyWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminSsoProvidersSsoProviderIdResponse, error) + + PutAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminSsoProvidersSsoProviderIdResponse, error) + + // GetAdminUsersWithResponse request + GetAdminUsersWithResponse(ctx context.Context, params *GetAdminUsersParams, reqEditors ...RequestEditorFn) (*GetAdminUsersResponse, error) + + // DeleteAdminUsersUserIdWithResponse request + DeleteAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminUsersUserIdResponse, error) + + // GetAdminUsersUserIdWithResponse request + GetAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminUsersUserIdResponse, error) + + // PutAdminUsersUserIdWithBodyWithResponse request with any body + PutAdminUsersUserIdWithBodyWithResponse(ctx context.Context, userId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdResponse, error) + + PutAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdResponse, error) + + // GetAdminUsersUserIdFactorsWithResponse request + GetAdminUsersUserIdFactorsWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminUsersUserIdFactorsResponse, error) + + // DeleteAdminUsersUserIdFactorsFactorIdWithResponse request + DeleteAdminUsersUserIdFactorsFactorIdWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminUsersUserIdFactorsFactorIdResponse, error) + + // PutAdminUsersUserIdFactorsFactorIdWithBodyWithResponse request with any body + PutAdminUsersUserIdFactorsFactorIdWithBodyWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdFactorsFactorIdResponse, error) + + PutAdminUsersUserIdFactorsFactorIdWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdFactorsFactorIdResponse, error) + + // PostInviteWithBodyWithResponse request with any body + PostInviteWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostInviteResponse, error) + + PostInviteWithResponse(ctx context.Context, body PostInviteJSONRequestBody, reqEditors ...RequestEditorFn) (*PostInviteResponse, error) +} + +type GetAdminAuditResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *[]struct { + CreatedAt *time.Time `json:"created_at,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + IpAddress *string `json:"ip_address,omitempty"` + Payload *struct { + // Action Usually one of these values: + // - login + // - logout + // - invite_accepted + // - user_signedup + // - user_invited + // - user_deleted + // - user_modified + // - user_recovery_requested + // - user_reauthenticate_requested + // - user_confirmation_requested + // - user_repeated_signup + // - user_updated_password + // - token_revoked + // - token_refreshed + // - generate_recovery_codes + // - factor_in_progress + // - factor_unenrolled + // - challenge_created + // - verification_attempted + // - factor_deleted + // - recovery_codes_deleted + // - factor_updated + // - mfa_code_login + Action *string `json:"action,omitempty"` + ActorId *string `json:"actor_id,omitempty"` + ActorName *string `json:"actor_name,omitempty"` + ActorUsername *string `json:"actor_username,omitempty"` + + // ActorViaSso Whether the actor used a SSO protocol (like SAML 2.0 or OIDC) to authenticate. + ActorViaSso *bool `json:"actor_via_sso,omitempty"` + + // LogType Usually one of these values: + // - account + // - team + // - token + // - user + // - factor + // - recovery_codes + LogType *string `json:"log_type,omitempty"` + Traits *map[string]interface{} `json:"traits,omitempty"` + } `json:"payload,omitempty"` + } + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse +} + +// Status returns HTTPResponse.Status +func (r GetAdminAuditResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminAuditResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PostAdminGenerateLinkResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *struct { + ActionLink *string `json:"action_link,omitempty"` + EmailOtp *string `json:"email_otp,omitempty"` + HashedToken *string `json:"hashed_token,omitempty"` + RedirectTo *string `json:"redirect_to,omitempty"` + VerificationType *string `json:"verification_type,omitempty"` + AdditionalProperties map[string]interface{} `json:"-"` + } + JSON400 *BadRequestResponse + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema + JSON422 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PostAdminGenerateLinkResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PostAdminGenerateLinkResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminSsoProvidersResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *struct { + Items *[]SSOProviderSchema `json:"items,omitempty"` + } +} + +// Status returns HTTPResponse.Status +func (r GetAdminSsoProvidersResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminSsoProvidersResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PostAdminSsoProvidersResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *SSOProviderSchema + JSON400 *BadRequestResponse + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse +} + +// Status returns HTTPResponse.Status +func (r PostAdminSsoProvidersResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PostAdminSsoProvidersResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type DeleteAdminSsoProvidersSsoProviderIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *SSOProviderSchema + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r DeleteAdminSsoProvidersSsoProviderIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r DeleteAdminSsoProvidersSsoProviderIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminSsoProvidersSsoProviderIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *SSOProviderSchema + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r GetAdminSsoProvidersSsoProviderIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminSsoProvidersSsoProviderIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PutAdminSsoProvidersSsoProviderIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *SSOProviderSchema + JSON400 *BadRequestResponse + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PutAdminSsoProvidersSsoProviderIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PutAdminSsoProvidersSsoProviderIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminUsersResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *struct { + // Deprecated: + Aud *string `json:"aud,omitempty"` + Users *[]UserSchema `json:"users,omitempty"` + } + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse +} + +// Status returns HTTPResponse.Status +func (r GetAdminUsersResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminUsersResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type DeleteAdminUsersUserIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *UserSchema + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r DeleteAdminUsersUserIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r DeleteAdminUsersUserIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminUsersUserIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *UserSchema + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r GetAdminUsersUserIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminUsersUserIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PutAdminUsersUserIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *UserSchema + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PutAdminUsersUserIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PutAdminUsersUserIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminUsersUserIdFactorsResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *[]MFAFactorSchema + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r GetAdminUsersUserIdFactorsResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminUsersUserIdFactorsResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type DeleteAdminUsersUserIdFactorsFactorIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *MFAFactorSchema + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r DeleteAdminUsersUserIdFactorsFactorIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r DeleteAdminUsersUserIdFactorsFactorIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PutAdminUsersUserIdFactorsFactorIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *MFAFactorSchema + JSON401 *UnauthorizedResponse + JSON403 *ForbiddenResponse + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PutAdminUsersUserIdFactorsFactorIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PutAdminUsersUserIdFactorsFactorIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PostInviteResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *UserSchema + JSON400 *BadRequestResponse + JSON422 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PostInviteResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PostInviteResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +// GetAdminAuditWithResponse request returning *GetAdminAuditResponse +func (c *ClientWithResponses) GetAdminAuditWithResponse(ctx context.Context, params *GetAdminAuditParams, reqEditors ...RequestEditorFn) (*GetAdminAuditResponse, error) { + rsp, err := c.GetAdminAudit(ctx, params, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminAuditResponse(rsp) +} + +// PostAdminGenerateLinkWithBodyWithResponse request with arbitrary body returning *PostAdminGenerateLinkResponse +func (c *ClientWithResponses) PostAdminGenerateLinkWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostAdminGenerateLinkResponse, error) { + rsp, err := c.PostAdminGenerateLinkWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostAdminGenerateLinkResponse(rsp) +} + +func (c *ClientWithResponses) PostAdminGenerateLinkWithResponse(ctx context.Context, body PostAdminGenerateLinkJSONRequestBody, reqEditors ...RequestEditorFn) (*PostAdminGenerateLinkResponse, error) { + rsp, err := c.PostAdminGenerateLink(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostAdminGenerateLinkResponse(rsp) +} + +// GetAdminSsoProvidersWithResponse request returning *GetAdminSsoProvidersResponse +func (c *ClientWithResponses) GetAdminSsoProvidersWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*GetAdminSsoProvidersResponse, error) { + rsp, err := c.GetAdminSsoProviders(ctx, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminSsoProvidersResponse(rsp) +} + +// PostAdminSsoProvidersWithBodyWithResponse request with arbitrary body returning *PostAdminSsoProvidersResponse +func (c *ClientWithResponses) PostAdminSsoProvidersWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostAdminSsoProvidersResponse, error) { + rsp, err := c.PostAdminSsoProvidersWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostAdminSsoProvidersResponse(rsp) +} + +func (c *ClientWithResponses) PostAdminSsoProvidersWithResponse(ctx context.Context, body PostAdminSsoProvidersJSONRequestBody, reqEditors ...RequestEditorFn) (*PostAdminSsoProvidersResponse, error) { + rsp, err := c.PostAdminSsoProviders(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostAdminSsoProvidersResponse(rsp) +} + +// DeleteAdminSsoProvidersSsoProviderIdWithResponse request returning *DeleteAdminSsoProvidersSsoProviderIdResponse +func (c *ClientWithResponses) DeleteAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminSsoProvidersSsoProviderIdResponse, error) { + rsp, err := c.DeleteAdminSsoProvidersSsoProviderId(ctx, ssoProviderId, reqEditors...) + if err != nil { + return nil, err + } + return ParseDeleteAdminSsoProvidersSsoProviderIdResponse(rsp) +} + +// GetAdminSsoProvidersSsoProviderIdWithResponse request returning *GetAdminSsoProvidersSsoProviderIdResponse +func (c *ClientWithResponses) GetAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminSsoProvidersSsoProviderIdResponse, error) { + rsp, err := c.GetAdminSsoProvidersSsoProviderId(ctx, ssoProviderId, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminSsoProvidersSsoProviderIdResponse(rsp) +} + +// PutAdminSsoProvidersSsoProviderIdWithBodyWithResponse request with arbitrary body returning *PutAdminSsoProvidersSsoProviderIdResponse +func (c *ClientWithResponses) PutAdminSsoProvidersSsoProviderIdWithBodyWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminSsoProvidersSsoProviderIdResponse, error) { + rsp, err := c.PutAdminSsoProvidersSsoProviderIdWithBody(ctx, ssoProviderId, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminSsoProvidersSsoProviderIdResponse(rsp) +} + +func (c *ClientWithResponses) PutAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminSsoProvidersSsoProviderIdResponse, error) { + rsp, err := c.PutAdminSsoProvidersSsoProviderId(ctx, ssoProviderId, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminSsoProvidersSsoProviderIdResponse(rsp) +} + +// GetAdminUsersWithResponse request returning *GetAdminUsersResponse +func (c *ClientWithResponses) GetAdminUsersWithResponse(ctx context.Context, params *GetAdminUsersParams, reqEditors ...RequestEditorFn) (*GetAdminUsersResponse, error) { + rsp, err := c.GetAdminUsers(ctx, params, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminUsersResponse(rsp) +} + +// DeleteAdminUsersUserIdWithResponse request returning *DeleteAdminUsersUserIdResponse +func (c *ClientWithResponses) DeleteAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminUsersUserIdResponse, error) { + rsp, err := c.DeleteAdminUsersUserId(ctx, userId, reqEditors...) + if err != nil { + return nil, err + } + return ParseDeleteAdminUsersUserIdResponse(rsp) +} + +// GetAdminUsersUserIdWithResponse request returning *GetAdminUsersUserIdResponse +func (c *ClientWithResponses) GetAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminUsersUserIdResponse, error) { + rsp, err := c.GetAdminUsersUserId(ctx, userId, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminUsersUserIdResponse(rsp) +} + +// PutAdminUsersUserIdWithBodyWithResponse request with arbitrary body returning *PutAdminUsersUserIdResponse +func (c *ClientWithResponses) PutAdminUsersUserIdWithBodyWithResponse(ctx context.Context, userId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdResponse, error) { + rsp, err := c.PutAdminUsersUserIdWithBody(ctx, userId, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminUsersUserIdResponse(rsp) +} + +func (c *ClientWithResponses) PutAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdResponse, error) { + rsp, err := c.PutAdminUsersUserId(ctx, userId, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminUsersUserIdResponse(rsp) +} + +// GetAdminUsersUserIdFactorsWithResponse request returning *GetAdminUsersUserIdFactorsResponse +func (c *ClientWithResponses) GetAdminUsersUserIdFactorsWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminUsersUserIdFactorsResponse, error) { + rsp, err := c.GetAdminUsersUserIdFactors(ctx, userId, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminUsersUserIdFactorsResponse(rsp) +} + +// DeleteAdminUsersUserIdFactorsFactorIdWithResponse request returning *DeleteAdminUsersUserIdFactorsFactorIdResponse +func (c *ClientWithResponses) DeleteAdminUsersUserIdFactorsFactorIdWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminUsersUserIdFactorsFactorIdResponse, error) { + rsp, err := c.DeleteAdminUsersUserIdFactorsFactorId(ctx, userId, factorId, reqEditors...) + if err != nil { + return nil, err + } + return ParseDeleteAdminUsersUserIdFactorsFactorIdResponse(rsp) +} + +// PutAdminUsersUserIdFactorsFactorIdWithBodyWithResponse request with arbitrary body returning *PutAdminUsersUserIdFactorsFactorIdResponse +func (c *ClientWithResponses) PutAdminUsersUserIdFactorsFactorIdWithBodyWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdFactorsFactorIdResponse, error) { + rsp, err := c.PutAdminUsersUserIdFactorsFactorIdWithBody(ctx, userId, factorId, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminUsersUserIdFactorsFactorIdResponse(rsp) +} + +func (c *ClientWithResponses) PutAdminUsersUserIdFactorsFactorIdWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdFactorsFactorIdResponse, error) { + rsp, err := c.PutAdminUsersUserIdFactorsFactorId(ctx, userId, factorId, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminUsersUserIdFactorsFactorIdResponse(rsp) +} + +// PostInviteWithBodyWithResponse request with arbitrary body returning *PostInviteResponse +func (c *ClientWithResponses) PostInviteWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostInviteResponse, error) { + rsp, err := c.PostInviteWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostInviteResponse(rsp) +} + +func (c *ClientWithResponses) PostInviteWithResponse(ctx context.Context, body PostInviteJSONRequestBody, reqEditors ...RequestEditorFn) (*PostInviteResponse, error) { + rsp, err := c.PostInvite(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostInviteResponse(rsp) +} + +// ParseGetAdminAuditResponse parses an HTTP response from a GetAdminAuditWithResponse call +func ParseGetAdminAuditResponse(rsp *http.Response) (*GetAdminAuditResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminAuditResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest []struct { + CreatedAt *time.Time `json:"created_at,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + IpAddress *string `json:"ip_address,omitempty"` + Payload *struct { + // Action Usually one of these values: + // - login + // - logout + // - invite_accepted + // - user_signedup + // - user_invited + // - user_deleted + // - user_modified + // - user_recovery_requested + // - user_reauthenticate_requested + // - user_confirmation_requested + // - user_repeated_signup + // - user_updated_password + // - token_revoked + // - token_refreshed + // - generate_recovery_codes + // - factor_in_progress + // - factor_unenrolled + // - challenge_created + // - verification_attempted + // - factor_deleted + // - recovery_codes_deleted + // - factor_updated + // - mfa_code_login + Action *string `json:"action,omitempty"` + ActorId *string `json:"actor_id,omitempty"` + ActorName *string `json:"actor_name,omitempty"` + ActorUsername *string `json:"actor_username,omitempty"` + + // ActorViaSso Whether the actor used a SSO protocol (like SAML 2.0 or OIDC) to authenticate. + ActorViaSso *bool `json:"actor_via_sso,omitempty"` + + // LogType Usually one of these values: + // - account + // - team + // - token + // - user + // - factor + // - recovery_codes + LogType *string `json:"log_type,omitempty"` + Traits *map[string]interface{} `json:"traits,omitempty"` + } `json:"payload,omitempty"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + } + + return response, nil +} + +// ParsePostAdminGenerateLinkResponse parses an HTTP response from a PostAdminGenerateLinkWithResponse call +func ParsePostAdminGenerateLinkResponse(rsp *http.Response) (*PostAdminGenerateLinkResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PostAdminGenerateLinkResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest struct { + ActionLink *string `json:"action_link,omitempty"` + EmailOtp *string `json:"email_otp,omitempty"` + HashedToken *string `json:"hashed_token,omitempty"` + RedirectTo *string `json:"redirect_to,omitempty"` + VerificationType *string `json:"verification_type,omitempty"` + AdditionalProperties map[string]interface{} `json:"-"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest BadRequestResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + } + + return response, nil +} + +// ParseGetAdminSsoProvidersResponse parses an HTTP response from a GetAdminSsoProvidersWithResponse call +func ParseGetAdminSsoProvidersResponse(rsp *http.Response) (*GetAdminSsoProvidersResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminSsoProvidersResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest struct { + Items *[]SSOProviderSchema `json:"items,omitempty"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + } + + return response, nil +} + +// ParsePostAdminSsoProvidersResponse parses an HTTP response from a PostAdminSsoProvidersWithResponse call +func ParsePostAdminSsoProvidersResponse(rsp *http.Response) (*PostAdminSsoProvidersResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PostAdminSsoProvidersResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest SSOProviderSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest BadRequestResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + } + + return response, nil +} + +// ParseDeleteAdminSsoProvidersSsoProviderIdResponse parses an HTTP response from a DeleteAdminSsoProvidersSsoProviderIdWithResponse call +func ParseDeleteAdminSsoProvidersSsoProviderIdResponse(rsp *http.Response) (*DeleteAdminSsoProvidersSsoProviderIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &DeleteAdminSsoProvidersSsoProviderIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest SSOProviderSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseGetAdminSsoProvidersSsoProviderIdResponse parses an HTTP response from a GetAdminSsoProvidersSsoProviderIdWithResponse call +func ParseGetAdminSsoProvidersSsoProviderIdResponse(rsp *http.Response) (*GetAdminSsoProvidersSsoProviderIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminSsoProvidersSsoProviderIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest SSOProviderSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParsePutAdminSsoProvidersSsoProviderIdResponse parses an HTTP response from a PutAdminSsoProvidersSsoProviderIdWithResponse call +func ParsePutAdminSsoProvidersSsoProviderIdResponse(rsp *http.Response) (*PutAdminSsoProvidersSsoProviderIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PutAdminSsoProvidersSsoProviderIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest SSOProviderSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest BadRequestResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseGetAdminUsersResponse parses an HTTP response from a GetAdminUsersWithResponse call +func ParseGetAdminUsersResponse(rsp *http.Response) (*GetAdminUsersResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminUsersResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest struct { + // Deprecated: + Aud *string `json:"aud,omitempty"` + Users *[]UserSchema `json:"users,omitempty"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + } + + return response, nil +} + +// ParseDeleteAdminUsersUserIdResponse parses an HTTP response from a DeleteAdminUsersUserIdWithResponse call +func ParseDeleteAdminUsersUserIdResponse(rsp *http.Response) (*DeleteAdminUsersUserIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &DeleteAdminUsersUserIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest UserSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseGetAdminUsersUserIdResponse parses an HTTP response from a GetAdminUsersUserIdWithResponse call +func ParseGetAdminUsersUserIdResponse(rsp *http.Response) (*GetAdminUsersUserIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminUsersUserIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest UserSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParsePutAdminUsersUserIdResponse parses an HTTP response from a PutAdminUsersUserIdWithResponse call +func ParsePutAdminUsersUserIdResponse(rsp *http.Response) (*PutAdminUsersUserIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PutAdminUsersUserIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest UserSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseGetAdminUsersUserIdFactorsResponse parses an HTTP response from a GetAdminUsersUserIdFactorsWithResponse call +func ParseGetAdminUsersUserIdFactorsResponse(rsp *http.Response) (*GetAdminUsersUserIdFactorsResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminUsersUserIdFactorsResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest []MFAFactorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseDeleteAdminUsersUserIdFactorsFactorIdResponse parses an HTTP response from a DeleteAdminUsersUserIdFactorsFactorIdWithResponse call +func ParseDeleteAdminUsersUserIdFactorsFactorIdResponse(rsp *http.Response) (*DeleteAdminUsersUserIdFactorsFactorIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &DeleteAdminUsersUserIdFactorsFactorIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest MFAFactorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParsePutAdminUsersUserIdFactorsFactorIdResponse parses an HTTP response from a PutAdminUsersUserIdFactorsFactorIdWithResponse call +func ParsePutAdminUsersUserIdFactorsFactorIdResponse(rsp *http.Response) (*PutAdminUsersUserIdFactorsFactorIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PutAdminUsersUserIdFactorsFactorIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest MFAFactorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest UnauthorizedResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ForbiddenResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParsePostInviteResponse parses an HTTP response from a PostInviteWithResponse call +func ParsePostInviteResponse(rsp *http.Response) (*PostInviteResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PostInviteResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest UserSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest BadRequestResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + } + + return response, nil +} diff --git a/client/admin/gen.go b/client/admin/gen.go new file mode 100644 index 000000000..8952b812f --- /dev/null +++ b/client/admin/gen.go @@ -0,0 +1,3 @@ +package admin + +//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config ./oapi-codegen.yaml ../../openapi.yaml diff --git a/client/admin/oapi-codegen.yaml b/client/admin/oapi-codegen.yaml new file mode 100644 index 000000000..a3aa63407 --- /dev/null +++ b/client/admin/oapi-codegen.yaml @@ -0,0 +1,7 @@ +package: admin +generate: + - client + - types +include-tags: + - admin +output: client.go diff --git a/cmd/admin_cmd.go b/cmd/admin_cmd.go index 24a2f3066..7997bb5c5 100644 --- a/cmd/admin_cmd.go +++ b/cmd/admin_cmd.go @@ -1,18 +1,18 @@ package cmd import ( - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" "github.com/gofrs/uuid" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" ) -var autoconfirm, isSuperAdmin, isAdmin bool -var audience, instanceID string +var autoconfirm, isAdmin bool +var audience string -func getAudience(c *conf.Configuration) string { +func getAudience(c *conf.GlobalConfiguration) string { if audience == "" { return c.JWT.Aud } @@ -27,10 +27,8 @@ func adminCmd() *cobra.Command { adminCmd.AddCommand(&adminCreateUserCmd, &adminDeleteUserCmd) adminCmd.PersistentFlags().StringVarP(&audience, "aud", "a", "", "Set the new user's audience") - adminCmd.PersistentFlags().StringVarP(&instanceID, "instance_id", "i", "", "Set the instance ID to interact with") adminCreateUserCmd.Flags().BoolVar(&autoconfirm, "confirm", false, "Automatically confirm user without sending an email") - adminCreateUserCmd.Flags().BoolVar(&isSuperAdmin, "superadmin", false, "Create user with superadmin privileges") adminCreateUserCmd.Flags().BoolVar(&isAdmin, "admin", false, "Create user with admin privileges") return adminCmd @@ -60,34 +58,24 @@ var adminDeleteUserCmd = cobra.Command{ }, } -var adminEditRoleCmd = cobra.Command{ - Use: "editrole", - Run: func(cmd *cobra.Command, args []string) { - execWithConfigAndArgs(cmd, adminEditRole, args) - }, -} - -func adminCreateUser(globalConfig *conf.GlobalConfiguration, config *conf.Configuration, args []string) { - iid := uuid.Must(uuid.FromString(instanceID)) - - db, err := storage.Dial(globalConfig) +func adminCreateUser(config *conf.GlobalConfiguration, args []string) { + db, err := storage.Dial(config) if err != nil { logrus.Fatalf("Error opening database: %+v", err) } defer db.Close() aud := getAudience(config) - if exists, err := models.IsDuplicatedEmail(db, iid, args[0], aud); exists { + if user, err := models.IsDuplicatedEmail(db, args[0], aud, nil); user != nil { logrus.Fatalf("Error creating new user: user already exists") } else if err != nil { logrus.Fatalf("Error checking user email: %+v", err) } - user, err := models.NewUser(iid, args[0], args[1], aud, nil) + user, err := models.NewUser("", args[0], args[1], aud, nil) if err != nil { logrus.Fatalf("Error creating new user: %+v", err) } - user.IsSuperAdmin = isSuperAdmin err = db.Transaction(func(tx *storage.Connection) error { var terr error @@ -119,19 +107,17 @@ func adminCreateUser(globalConfig *conf.GlobalConfiguration, config *conf.Config logrus.Infof("Created user: %s", args[0]) } -func adminDeleteUser(globalConfig *conf.GlobalConfiguration, config *conf.Configuration, args []string) { - iid := uuid.Must(uuid.FromString(instanceID)) - - db, err := storage.Dial(globalConfig) +func adminDeleteUser(config *conf.GlobalConfiguration, args []string) { + db, err := storage.Dial(config) if err != nil { logrus.Fatalf("Error opening database: %+v", err) } defer db.Close() - user, err := models.FindUserByEmailAndAudience(db, iid, args[0], getAudience(config)) + user, err := models.FindUserByEmailAndAudience(db, args[0], getAudience(config)) if err != nil { userID := uuid.Must(uuid.FromString(args[0])) - user, err = models.FindUserByInstanceIDAndID(db, iid, userID) + user, err = models.FindUserByID(db, userID) if err != nil { logrus.Fatalf("Error finding user (%s): %+v", userID, err) } @@ -143,36 +129,3 @@ func adminDeleteUser(globalConfig *conf.GlobalConfiguration, config *conf.Config logrus.Infof("Removed user: %s", args[0]) } - -func adminEditRole(globalConfig *conf.GlobalConfiguration, config *conf.Configuration, args []string) { - iid := uuid.Must(uuid.FromString(instanceID)) - - db, err := storage.Dial(globalConfig) - if err != nil { - logrus.Fatalf("Error opening database: %+v", err) - } - defer db.Close() - - user, err := models.FindUserByEmailAndAudience(db, iid, args[0], getAudience(config)) - if err != nil { - userID := uuid.Must(uuid.FromString(args[0])) - user, err = models.FindUserByInstanceIDAndID(db, iid, userID) - if err != nil { - logrus.Fatalf("Error finding user (%s): %+v", userID, err) - } - } - - user.IsSuperAdmin = isSuperAdmin - - if len(args) > 0 { - user.Role = args[0] - } else if isAdmin { - user.Role = config.JWT.AdminGroupName - } - - if err = db.UpdateOnly(user, "role", "is_super_admin"); err != nil { - logrus.Fatalf("Error updating role for user (%s): %+v", args[0], err) - } - - logrus.Infof("Updated user: %s", args[0]) -} diff --git a/cmd/migrate_cmd.go b/cmd/migrate_cmd.go index e0a5b90d5..c9b80dfb8 100644 --- a/cmd/migrate_cmd.go +++ b/cmd/migrate_cmd.go @@ -1,17 +1,20 @@ package cmd import ( + "embed" + "fmt" "net/url" "os" - "github.com/gobuffalo/pop/v5" - "github.com/gobuffalo/pop/v5/logging" - "github.com/netlify/gotrue/conf" + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/pop/v6/logging" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) +var EmbeddedMigrations embed.FS + var migrateCmd = cobra.Command{ Use: "migrate", Long: "Migrate database strucutures. This will create new tables and add missing columns and indexes.", @@ -19,10 +22,8 @@ var migrateCmd = cobra.Command{ } func migrate(cmd *cobra.Command, args []string) { - globalConfig, err := conf.LoadGlobal(configFile) - if err != nil { - logrus.Fatalf("Failed to load configuration: %+v", err) - } + globalConfig := loadGlobalConfig(cmd.Context()) + if globalConfig.DB.Driver == "" && globalConfig.DB.URL != "" { u, err := url.Parse(globalConfig.DB.URL) if err != nil { @@ -31,7 +32,7 @@ func migrate(cmd *cobra.Command, args []string) { globalConfig.DB.Driver = u.Scheme } - log := logrus.New() + log := logrus.StandardLogger() pop.Debug = false if globalConfig.Logging.Level != "" { @@ -46,19 +47,26 @@ func migrate(cmd *cobra.Command, args []string) { } if level != logrus.DebugLevel { var noopLogger = func(lvl logging.Level, s string, args ...interface{}) { - return } // Hide pop migration logging pop.SetLogger(noopLogger) } } + u, _ := url.Parse(globalConfig.DB.URL) + processedUrl := globalConfig.DB.URL + if len(u.Query()) != 0 { + processedUrl = fmt.Sprintf("%s&application_name=gotrue_migrations", processedUrl) + } else { + processedUrl = fmt.Sprintf("%s?application_name=gotrue_migrations", processedUrl) + } deets := &pop.ConnectionDetails{ Dialect: globalConfig.DB.Driver, - URL: globalConfig.DB.URL, + URL: processedUrl, } deets.Options = map[string]string{ "migration_table_name": "schema_migrations", + "Namespace": globalConfig.DB.Namespace, } db, err := pop.NewConnection(deets) @@ -71,11 +79,14 @@ func migrate(cmd *cobra.Command, args []string) { log.Fatalf("%+v", errors.Wrap(err, "checking database connection")) } - log.Debugf("Reading migrations from %s", globalConfig.DB.MigrationsPath) - mig, err := pop.NewFileMigrator(globalConfig.DB.MigrationsPath, db) + log.Debugf("Reading migrations from executable") + box, err := pop.NewMigrationBox(EmbeddedMigrations, db) if err != nil { log.Fatalf("%+v", errors.Wrap(err, "creating db migrator")) } + + mig := box.Migrator + log.Debugf("before status") if log.Level == logrus.DebugLevel { @@ -88,11 +99,11 @@ func migrate(cmd *cobra.Command, args []string) { // turn off schema dump mig.SchemaPath = "" - err = mig.Up() + count, err := mig.UpTo(0) if err != nil { log.Fatalf("%v", errors.Wrap(err, "running db migrations")) } else { - log.Infof("GoTrue migrations applied successfully") + log.WithField("count", count).Infof("GoTrue migrations applied successfully") } log.Debugf("after status") diff --git a/cmd/multi_cmd.go b/cmd/multi_cmd.go deleted file mode 100644 index 9c416ee01..000000000 --- a/cmd/multi_cmd.go +++ /dev/null @@ -1,51 +0,0 @@ -package cmd - -import ( - "context" - "fmt" - "time" - - "github.com/netlify/gotrue/api" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/storage" - "github.com/sirupsen/logrus" - "github.com/spf13/cobra" -) - -var multiCmd = cobra.Command{ - Use: "multi", - Long: "Start multi-tenant API server", - Run: multi, -} - -func multi(cmd *cobra.Command, args []string) { - globalConfig, err := conf.LoadGlobal(configFile) - if err != nil { - logrus.Fatalf("Failed to load configuration: %+v", err) - } - if globalConfig.OperatorToken == "" { - logrus.Fatal("Operator token secret is required") - } - - var db *storage.Connection - // try a couple times to connect to the database - for i := 1; i <= 3; i++ { - time.Sleep(time.Duration((i-1)*100) * time.Millisecond) - db, err = storage.Dial(globalConfig) - if err == nil { - break - } - logrus.WithError(err).WithField("attempt", i).Warn("Error connecting to database") - } - if err != nil { - logrus.Fatalf("Error opening database: %+v", err) - } - defer db.Close() - - globalConfig.MultiInstanceMode = true - api := api.NewAPIWithVersion(context.Background(), globalConfig, db, Version) - - l := fmt.Sprintf("%v:%v", globalConfig.API.Host, globalConfig.API.Port) - logrus.Infof("GoTrue API started on: %s", l) - api.ListenAndServe(l) -} diff --git a/cmd/root_cmd.go b/cmd/root_cmd.go index 557591e14..e8783d463 100644 --- a/cmd/root_cmd.go +++ b/cmd/root_cmd.go @@ -1,52 +1,63 @@ package cmd import ( + "context" + "github.com/sirupsen/logrus" "github.com/spf13/cobra" - - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/observability" ) -var configFile = "" +var ( + configFile = "" + watchDir = "" +) var rootCmd = cobra.Command{ Use: "gotrue", Run: func(cmd *cobra.Command, args []string) { - migrate(&migrateCmd, args) - execWithConfig(cmd, serve) + migrate(cmd, args) + serve(cmd.Context()) }, } // RootCommand will setup and return the root command func RootCommand() *cobra.Command { - rootCmd.AddCommand(&serveCmd, &migrateCmd, &multiCmd, &versionCmd, adminCmd()) - rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "the config file to use") - + rootCmd.AddCommand(&serveCmd, &migrateCmd, &versionCmd, adminCmd()) + rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "base configuration file to load") + rootCmd.PersistentFlags().StringVarP(&watchDir, "config-dir", "d", "", "directory containing a sorted list of config files to watch for changes") return &rootCmd } -func execWithConfig(cmd *cobra.Command, fn func(globalConfig *conf.GlobalConfiguration, config *conf.Configuration)) { - globalConfig, err := conf.LoadGlobal(configFile) - if err != nil { - logrus.Fatalf("Failed to load configuration: %+v", err) +func loadGlobalConfig(ctx context.Context) *conf.GlobalConfiguration { + if ctx == nil { + panic("context must not be nil") } - config, err := conf.LoadConfig(configFile) + + config, err := conf.LoadGlobal(configFile) if err != nil { logrus.Fatalf("Failed to load configuration: %+v", err) } - fn(globalConfig, config) -} + if err := observability.ConfigureLogging(&config.Logging); err != nil { + logrus.WithError(err).Error("unable to configure logging") + } -func execWithConfigAndArgs(cmd *cobra.Command, fn func(globalConfig *conf.GlobalConfiguration, config *conf.Configuration, args []string), args []string) { - globalConfig, err := conf.LoadGlobal(configFile) - if err != nil { - logrus.Fatalf("Failed to load configuration: %+v", err) + if err := observability.ConfigureTracing(ctx, &config.Tracing); err != nil { + logrus.WithError(err).Error("unable to configure tracing") } - config, err := conf.LoadConfig(configFile) - if err != nil { - logrus.Fatalf("Failed to load configuration: %+v", err) + + if err := observability.ConfigureMetrics(ctx, &config.Metrics); err != nil { + logrus.WithError(err).Error("unable to configure metrics") + } + + if err := observability.ConfigureProfiler(ctx, &config.Profiler); err != nil { + logrus.WithError(err).Error("unable to configure profiler") } + return config +} - fn(globalConfig, config, args) +func execWithConfigAndArgs(cmd *cobra.Command, fn func(config *conf.GlobalConfiguration, args []string), args []string) { + fn(loadGlobalConfig(cmd.Context()), args) } diff --git a/cmd/serve_cmd.go b/cmd/serve_cmd.go index 4b5cbc0d1..2cbd3db48 100644 --- a/cmd/serve_cmd.go +++ b/cmd/serve_cmd.go @@ -2,38 +2,128 @@ package cmd import ( "context" - "fmt" + "net" + "net/http" + "sync" + "syscall" + "time" - "github.com/netlify/gotrue/api" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/storage" - "github.com/gofrs/uuid" + "golang.org/x/sys/unix" + + "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/supabase/auth/internal/api" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/reloader" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" ) var serveCmd = cobra.Command{ Use: "serve", Long: "Start API server", Run: func(cmd *cobra.Command, args []string) { - execWithConfig(cmd, serve) + serve(cmd.Context()) }, } -func serve(globalConfig *conf.GlobalConfiguration, config *conf.Configuration) { - db, err := storage.Dial(globalConfig) +func serve(ctx context.Context) { + if err := conf.LoadFile(configFile); err != nil { + logrus.WithError(err).Fatal("unable to load config") + } + + if err := conf.LoadDirectory(watchDir); err != nil { + logrus.WithError(err).Error("unable to load config from watch dir") + } + + config, err := conf.LoadGlobalFromEnv() if err != nil { - logrus.Fatalf("Error opening database: %+v", err) + logrus.WithError(err).Fatal("unable to load config") } - defer db.Close() - ctx, err := api.WithInstanceConfig(context.Background(), config, uuid.Nil) + db, err := storage.Dial(config) if err != nil { - logrus.Fatalf("Error loading instance config: %+v", err) + logrus.Fatalf("error opening database: %+v", err) + } + defer db.Close() + + addr := net.JoinHostPort(config.API.Host, config.API.Port) + + opts := []api.Option{ + api.NewLimiterOptions(config), } - api := api.NewAPIWithVersion(ctx, globalConfig, db, Version) + a := api.NewAPIWithVersion(config, db, utilities.Version, opts...) + ah := reloader.NewAtomicHandler(a) + logrus.WithField("version", a.Version()).Infof("GoTrue API started on: %s", addr) + + baseCtx, baseCancel := context.WithCancel(context.Background()) + defer baseCancel() + + httpSrv := &http.Server{ + Addr: addr, + Handler: ah, + ReadHeaderTimeout: 2 * time.Second, // to mitigate a Slowloris attack + BaseContext: func(net.Listener) context.Context { + return baseCtx + }, + } + log := logrus.WithField("component", "api") + + var wg sync.WaitGroup + defer wg.Wait() // Do not return to caller until this goroutine is done. - l := fmt.Sprintf("%v:%v", globalConfig.API.Host, globalConfig.API.Port) - logrus.Infof("GoTrue API started on: %s", l) - api.ListenAndServe(l) + if watchDir != "" { + wg.Add(1) + go func() { + defer wg.Done() + + fn := func(latestCfg *conf.GlobalConfiguration) { + log.Info("reloading api with new configuration") + latestAPI := api.NewAPIWithVersion( + latestCfg, db, utilities.Version, opts...) + ah.Store(latestAPI) + } + + rl := reloader.NewReloader(watchDir) + if err := rl.Watch(ctx, fn); err != nil { + log.WithError(err).Error("watcher is exiting") + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + + <-ctx.Done() + + defer baseCancel() // close baseContext + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Minute) + defer shutdownCancel() + + if err := httpSrv.Shutdown(shutdownCtx); err != nil && !errors.Is(err, context.Canceled) { + log.WithError(err).Error("shutdown failed") + } + }() + + lc := net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + var serr error + if err := c.Control(func(fd uintptr) { + serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }); err != nil { + return err + } + return serr + }, + } + listener, err := lc.Listen(ctx, "tcp", addr) + if err != nil { + log.WithError(err).Fatal("http server listen failed") + } + if err := httpSrv.Serve(listener); err != nil { + log.WithError(err).Fatal("http server serve failed") + } } diff --git a/cmd/version_cmd.go b/cmd/version_cmd.go index 01f099912..cb555d4b1 100644 --- a/cmd/version_cmd.go +++ b/cmd/version_cmd.go @@ -4,16 +4,14 @@ import ( "fmt" "github.com/spf13/cobra" + "github.com/supabase/auth/internal/utilities" ) -// Version is the SHA of the git commit from which this binary was built. -var Version string - var versionCmd = cobra.Command{ Run: showVersion, Use: "version", } func showVersion(cmd *cobra.Command, args []string) { - fmt.Println(Version) + fmt.Println(utilities.Version) } diff --git a/conf/configuration.go b/conf/configuration.go deleted file mode 100644 index 87703a580..000000000 --- a/conf/configuration.go +++ /dev/null @@ -1,410 +0,0 @@ -package conf - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "os" - "time" - - "github.com/joho/godotenv" - "github.com/kelseyhightower/envconfig" -) - -// OAuthProviderConfiguration holds all config related to external account providers. -type OAuthProviderConfiguration struct { - ClientID string `json:"client_id" split_words:"true"` - Secret string `json:"secret"` - RedirectURI string `json:"redirect_uri" split_words:"true"` - URL string `json:"url"` - Enabled bool `json:"enabled"` -} - -type EmailProviderConfiguration struct { - Enabled bool `json:"enabled" default:"true"` -} - -type SamlProviderConfiguration struct { - Enabled bool `json:"enabled"` - MetadataURL string `json:"metadata_url" envconfig:"METADATA_URL"` - APIBase string `json:"api_base" envconfig:"API_BASE"` - Name string `json:"name"` - SigningCert string `json:"signing_cert" envconfig:"SIGNING_CERT"` - SigningKey string `json:"signing_key" envconfig:"SIGNING_KEY"` -} - -// DBConfiguration holds all the database related configuration. -type DBConfiguration struct { - Driver string `json:"driver" required:"true"` - URL string `json:"url" envconfig:"DATABASE_URL" required:"true"` - MigrationsPath string `json:"migrations_path" split_words:"true" default:"./migrations"` -} - -// JWTConfiguration holds all the JWT related configuration. -type JWTConfiguration struct { - Secret string `json:"secret" required:"true"` - Exp int `json:"exp"` - Aud string `json:"aud"` - AdminGroupName string `json:"admin_group_name" split_words:"true"` - AdminRoles []string `json:"admin_roles" split_words:"true"` - DefaultGroupName string `json:"default_group_name" split_words:"true"` -} - -// GlobalConfiguration holds all the configuration that applies to all instances. -type GlobalConfiguration struct { - API struct { - Host string - Port int `envconfig:"PORT" default:"8081"` - Endpoint string - RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"` - ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL"` - } - DB DBConfiguration - External ProviderConfiguration - Logging LoggingConfig `envconfig:"LOG"` - OperatorToken string `split_words:"true" required:"false"` - MultiInstanceMode bool - Tracing TracingConfig - SMTP SMTPConfiguration - RateLimitHeader string `split_words:"true"` - RateLimitEmailSent float64 `split_words:"true" default:"30"` -} - -// EmailContentConfiguration holds the configuration for emails, both subjects and template URLs. -type EmailContentConfiguration struct { - Invite string `json:"invite"` - Confirmation string `json:"confirmation"` - Recovery string `json:"recovery"` - EmailChange string `json:"email_change" split_words:"true"` - MagicLink string `json:"magic_link" split_words:"true"` -} - -type ProviderConfiguration struct { - Apple OAuthProviderConfiguration `json:"apple"` - Azure OAuthProviderConfiguration `json:"azure"` - Bitbucket OAuthProviderConfiguration `json:"bitbucket"` - Discord OAuthProviderConfiguration `json:"discord"` - Facebook OAuthProviderConfiguration `json:"facebook"` - Github OAuthProviderConfiguration `json:"github"` - Gitlab OAuthProviderConfiguration `json:"gitlab"` - Google OAuthProviderConfiguration `json:"google"` - Notion OAuthProviderConfiguration `json:"notion"` - Linkedin OAuthProviderConfiguration `json:"linkedin"` - Spotify OAuthProviderConfiguration `json:"spotify"` - Slack OAuthProviderConfiguration `json:"slack"` - Twitter OAuthProviderConfiguration `json:"twitter"` - Twitch OAuthProviderConfiguration `json:"twitch"` - Email EmailProviderConfiguration `json:"email"` - Phone PhoneProviderConfiguration `json:"phone"` - Saml SamlProviderConfiguration `json:"saml"` - IosBundleId string `json:"ios_bundle_id" split_words:"true"` - RedirectURL string `json:"redirect_url"` -} - -type SMTPConfiguration struct { - MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` - Host string `json:"host"` - Port int `json:"port,omitempty" default:"587"` - User string `json:"user"` - Pass string `json:"pass,omitempty"` - AdminEmail string `json:"admin_email" split_words:"true"` - SenderName string `json:"sender_name" split_words:"true"` -} - -type MailerConfiguration struct { - Autoconfirm bool `json:"autoconfirm"` - Subjects EmailContentConfiguration `json:"subjects"` - Templates EmailContentConfiguration `json:"templates"` - URLPaths EmailContentConfiguration `json:"url_paths"` - SecureEmailChangeEnabled bool `json:"secure_email_change_enabled" split_words:"true" default:"true"` -} - -type PhoneProviderConfiguration struct { - Enabled bool `json:"enabled"` -} - -type SmsProviderConfiguration struct { - Autoconfirm bool `json:"autoconfirm"` - MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` - OtpExp uint `json:"otp_exp" split_words:"true"` - OtpLength int `json:"otp_length" split_words:"true"` - Provider string `json:"provider"` - Template string `json:"template"` - Twilio TwilioProviderConfiguration `json:"twilio"` - Messagebird MessagebirdProviderConfiguration `json:"messagebird"` - Textlocal TextlocalProviderConfiguration `json:"textlocal"` - Vonage VonageProviderConfiguration `json:"vonage"` -} - -type TwilioProviderConfiguration struct { - AccountSid string `json:"account_sid" split_words:"true"` - AuthToken string `json:"auth_token" split_words:"true"` - MessageServiceSid string `json:"message_service_sid" split_words:"true"` -} - -type MessagebirdProviderConfiguration struct { - AccessKey string `json:"access_key" split_words:"true"` - Originator string `json:"originator" split_words:"true"` -} - -type TextlocalProviderConfiguration struct { - ApiKey string `json:"api_key" split_words:"true"` - Sender string `json:"sender" split_words:"true"` -} - -type VonageProviderConfiguration struct { - ApiKey string `json:"api_key" split_words:"true"` - ApiSecret string `json:"api_secret" split_words:"true"` - From string `json:"from" split_words:"true"` -} - -type CaptchaConfiguration struct { - Enabled bool `json:"enabled" default:"false"` - Provider string `json:"provider" default:"hcaptcha"` - Secret string `json:"provider_secret"` -} - -type SecurityConfiguration struct { - Captcha CaptchaConfiguration `json:"captcha"` - RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"` -} - -// Configuration holds all the per-instance configuration. -type Configuration struct { - SiteURL string `json:"site_url" split_words:"true" required:"true"` - URIAllowList []string `json:"uri_allow_list" split_words:"true"` - PasswordMinLength int `json:"password_min_length" default:"6"` - JWT JWTConfiguration `json:"jwt"` - SMTP SMTPConfiguration `json:"smtp"` - Mailer MailerConfiguration `json:"mailer"` - External ProviderConfiguration `json:"external"` - Sms SmsProviderConfiguration `json:"sms"` - DisableSignup bool `json:"disable_signup" split_words:"true"` - Webhook WebhookConfig `json:"webhook" split_words:"true"` - Security SecurityConfiguration `json:"security"` - Cookie struct { - Key string `json:"key"` - Domain string `json:"domain"` - Duration int `json:"duration"` - } `json:"cookies"` -} - -func loadEnvironment(filename string) error { - var err error - if filename != "" { - err = godotenv.Load(filename) - } else { - err = godotenv.Load() - // handle if .env file does not exist, this is OK - if os.IsNotExist(err) { - return nil - } - } - return err -} - -type WebhookConfig struct { - URL string `json:"url"` - Retries int `json:"retries"` - TimeoutSec int `json:"timeout_sec"` - Secret string `json:"secret"` - Events []string `json:"events"` -} - -func (w *WebhookConfig) HasEvent(event string) bool { - for _, name := range w.Events { - if event == name { - return true - } - } - return false -} - -// LoadGlobal loads configuration from file and environment variables. -func LoadGlobal(filename string) (*GlobalConfiguration, error) { - if err := loadEnvironment(filename); err != nil { - return nil, err - } - - config := new(GlobalConfiguration) - if err := envconfig.Process("gotrue", config); err != nil { - return nil, err - } - - if _, err := ConfigureLogging(&config.Logging); err != nil { - return nil, err - } - - ConfigureTracing(&config.Tracing) - - if config.SMTP.MaxFrequency == 0 { - config.SMTP.MaxFrequency = 1 * time.Minute - } - return config, nil -} - -// LoadConfig loads per-instance configuration. -func LoadConfig(filename string) (*Configuration, error) { - if err := loadEnvironment(filename); err != nil { - return nil, err - } - - config := new(Configuration) - if err := envconfig.Process("gotrue", config); err != nil { - return nil, err - } - config.ApplyDefaults() - return config, nil -} - -// ApplyDefaults sets defaults for a Configuration -func (config *Configuration) ApplyDefaults() { - if config.JWT.AdminGroupName == "" { - config.JWT.AdminGroupName = "admin" - } - - if config.JWT.AdminRoles == nil || len(config.JWT.AdminRoles) == 0 { - config.JWT.AdminRoles = []string{"service_role", "supabase_admin"} - } - - if config.JWT.Exp == 0 { - config.JWT.Exp = 3600 - } - - if config.Mailer.URLPaths.Invite == "" { - config.Mailer.URLPaths.Invite = "/" - } - if config.Mailer.URLPaths.Confirmation == "" { - config.Mailer.URLPaths.Confirmation = "/" - } - if config.Mailer.URLPaths.Recovery == "" { - config.Mailer.URLPaths.Recovery = "/" - } - if config.Mailer.URLPaths.EmailChange == "" { - config.Mailer.URLPaths.EmailChange = "/" - } - - if config.SMTP.MaxFrequency == 0 { - config.SMTP.MaxFrequency = 1 * time.Minute - } - - if config.Sms.MaxFrequency == 0 { - config.Sms.MaxFrequency = 1 * time.Minute - } - - if config.Sms.OtpExp == 0 { - config.Sms.OtpExp = 60 - } - - if config.Sms.OtpLength == 0 || config.Sms.OtpLength < 6 || config.Sms.OtpLength > 10 { - // 6-digit otp by default - config.Sms.OtpLength = 6 - } - - if len(config.Sms.Template) == 0 { - config.Sms.Template = "" - } - - if config.Cookie.Key == "" { - config.Cookie.Key = "sb" - } - - if config.Cookie.Domain == "" { - config.Cookie.Domain = "" - } - - if config.Cookie.Duration == 0 { - config.Cookie.Duration = 86400 - } - - if config.URIAllowList == nil { - config.URIAllowList = []string{} - } -} - -func (config *Configuration) Value() (driver.Value, error) { - data, err := json.Marshal(config) - if err != nil { - return driver.Value(""), err - } - return driver.Value(string(data)), nil -} - -func (config *Configuration) Scan(src interface{}) error { - var source []byte - switch v := src.(type) { - case string: - source = []byte(v) - case []byte: - source = v - default: - return errors.New("Invalid data type for Configuration") - } - - if len(source) == 0 { - source = []byte("{}") - } - return json.Unmarshal(source, &config) -} - -func (o *OAuthProviderConfiguration) Validate() error { - if !o.Enabled { - return errors.New("Provider is not enabled") - } - if o.ClientID == "" { - return errors.New("Missing Oauth client ID") - } - if o.Secret == "" { - return errors.New("Missing Oauth secret") - } - if o.RedirectURI == "" { - return errors.New("Missing redirect URI") - } - return nil -} - -func (t *TwilioProviderConfiguration) Validate() error { - if t.AccountSid == "" { - return errors.New("Missing Twilio account SID") - } - if t.AuthToken == "" { - return errors.New("Missing Twilio auth token") - } - if t.MessageServiceSid == "" { - return errors.New("Missing Twilio message service SID or Twilio phone number") - } - return nil -} - -func (t *MessagebirdProviderConfiguration) Validate() error { - if t.AccessKey == "" { - return errors.New("Missing Messagebird access key") - } - if t.Originator == "" { - return errors.New("Missing Messagebird originator") - } - return nil -} - -func (t *TextlocalProviderConfiguration) Validate() error { - if t.ApiKey == "" { - return errors.New("Missing Textlocal API key") - } - if t.Sender == "" { - return errors.New("Missing Textlocal sender") - } - return nil -} - -func (t *VonageProviderConfiguration) Validate() error { - if t.ApiKey == "" { - return errors.New("Missing Vonage API key") - } - if t.ApiSecret == "" { - return errors.New("Missing Vonage API secret") - } - if t.From == "" { - return errors.New("Missing Vonage 'from' parameter") - } - return nil -} diff --git a/conf/configuration_test.go b/conf/configuration_test.go deleted file mode 100644 index dae60605c..000000000 --- a/conf/configuration_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package conf - -import ( - "os" - "testing" - - "github.com/opentracing/opentracing-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMain(m *testing.M) { - defer os.Clearenv() - os.Exit(m.Run()) -} - -func TestGlobal(t *testing.T) { - os.Setenv("GOTRUE_DB_DRIVER", "mysql") - os.Setenv("GOTRUE_DB_DATABASE_URL", "fake") - os.Setenv("GOTRUE_OPERATOR_TOKEN", "token") - os.Setenv("GOTRUE_API_REQUEST_ID_HEADER", "X-Request-ID") - gc, err := LoadGlobal("") - require.NoError(t, err) - require.NotNil(t, gc) - assert.Equal(t, "X-Request-ID", gc.API.RequestIDHeader) -} - -func TestTracing(t *testing.T) { - os.Setenv("GOTRUE_DB_DRIVER", "mysql") - os.Setenv("GOTRUE_DB_DATABASE_URL", "fake") - os.Setenv("GOTRUE_OPERATOR_TOKEN", "token") - os.Setenv("GOTRUE_TRACING_SERVICE_NAME", "identity") - os.Setenv("GOTRUE_TRACING_PORT", "8126") - os.Setenv("GOTRUE_TRACING_HOST", "127.0.0.1") - os.Setenv("GOTRUE_TRACING_TAGS", "tag1:value1,tag2:value2") - - gc, _ := LoadGlobal("") - tc := opentracing.GlobalTracer() - - assert.Equal(t, opentracing.NoopTracer{}, tc) - assert.Equal(t, false, gc.Tracing.Enabled) - assert.Equal(t, "identity", gc.Tracing.ServiceName) - assert.Equal(t, "8126", gc.Tracing.Port) - assert.Equal(t, "127.0.0.1", gc.Tracing.Host) - assert.Equal(t, map[string]string{"tag1": "value1", "tag2": "value2"}, gc.Tracing.Tags) -} diff --git a/conf/logging.go b/conf/logging.go deleted file mode 100644 index 7ecab3409..000000000 --- a/conf/logging.go +++ /dev/null @@ -1,60 +0,0 @@ -package conf - -import ( - "os" - "time" - - "github.com/sirupsen/logrus" -) - -type LoggingConfig struct { - Level string `mapstructure:"log_level" json:"log_level"` - File string `mapstructure:"log_file" json:"log_file"` - DisableColors bool `mapstructure:"disable_colors" split_words:"true" json:"disable_colors"` - QuoteEmptyFields bool `mapstructure:"quote_empty_fields" split_words:"true" json:"quote_empty_fields"` - TSFormat string `mapstructure:"ts_format" json:"ts_format"` - Fields map[string]interface{} `mapstructure:"fields" json:"fields"` -} - -func ConfigureLogging(config *LoggingConfig) (*logrus.Entry, error) { - logger := logrus.New() - - tsFormat := time.RFC3339Nano - if config.TSFormat != "" { - tsFormat = config.TSFormat - } - // always use the full timestamp - logger.SetFormatter(&logrus.TextFormatter{ - FullTimestamp: true, - DisableTimestamp: false, - TimestampFormat: tsFormat, - DisableColors: config.DisableColors, - QuoteEmptyFields: config.QuoteEmptyFields, - }) - - // use a file if you want - if config.File != "" { - f, errOpen := os.OpenFile(config.File, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0664) - if errOpen != nil { - return nil, errOpen - } - logger.SetOutput(f) - logger.Infof("Set output file to %s", config.File) - } - - if config.Level != "" { - level, err := logrus.ParseLevel(config.Level) - if err != nil { - return nil, err - } - logger.SetLevel(level) - logger.Debug("Set log level to: " + logger.GetLevel().String()) - } - - f := logrus.Fields{} - for k, v := range config.Fields { - f[k] = v - } - - return logger.WithFields(f), nil -} diff --git a/conf/tracing.go b/conf/tracing.go deleted file mode 100644 index 8f51ffe95..000000000 --- a/conf/tracing.go +++ /dev/null @@ -1,38 +0,0 @@ -package conf - -import ( - "fmt" - - "github.com/opentracing/opentracing-go" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" -) - -type TracingConfig struct { - Enabled bool `default:"false"` - Host string - Port string - ServiceName string `default:"gotrue" split_words:"true"` - Tags map[string]string -} - -func (tc *TracingConfig) tracingAddr() string { - return fmt.Sprintf("%s:%s", tc.Host, tc.Port) -} - -func ConfigureTracing(tc *TracingConfig) { - var t opentracing.Tracer = opentracing.NoopTracer{} - if tc.Enabled { - tracerOps := []tracer.StartOption{ - tracer.WithServiceName(tc.ServiceName), - tracer.WithAgentAddr(tc.tracingAddr()), - } - - for k, v := range tc.Tags { - tracerOps = append(tracerOps, tracer.WithGlobalTag(k, v)) - } - - t = opentracer.New(tracerOps...) - } - opentracing.SetGlobalTracer(t) -} diff --git a/crypto/crypto.go b/crypto/crypto.go deleted file mode 100644 index 3c788f95c..000000000 --- a/crypto/crypto.go +++ /dev/null @@ -1,39 +0,0 @@ -package crypto - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "io" - "math" - "math/big" - "strconv" - "strings" - - "github.com/pkg/errors" -) - -// SecureToken creates a new random token -func SecureToken() string { - b := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, b); err != nil { - panic(err.Error()) // rand should never fail - } - return removePadding(base64.URLEncoding.EncodeToString(b)) -} - -func removePadding(token string) string { - return strings.TrimRight(token, "=") -} - -// GenerateOtp generates a random n digit otp -func GenerateOtp(digits int) (string, error) { - upper := math.Pow10(digits) - val, err := rand.Int(rand.Reader, big.NewInt(int64(upper))) - if err != nil { - return "", errors.WithMessage(err, "Error generating otp") - } - expr := "%0" + strconv.Itoa(digits) + "v" - otp := fmt.Sprintf(expr, val.String()) - return otp, nil -} diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml new file mode 100644 index 000000000..47ae53d27 --- /dev/null +++ b/docker-compose-dev.yml @@ -0,0 +1,34 @@ +version: "3.9" +services: + auth: + container_name: auth + depends_on: + - postgres + build: + context: ./ + dockerfile: Dockerfile.dev + ports: + - '9999:9999' + - '9100:9100' + environment: + - GOTRUE_DB_MIGRATIONS_PATH=/go/src/github.com/supabase/auth/migrations + volumes: + - ./:/go/src/github.com/supabase/auth + command: CompileDaemon --build="make build" --directory=/go/src/github.com/supabase/auth --recursive=true -pattern="(.+\.go|.+\.env)" -exclude=auth -exclude=auth-arm64 -exclude=.env --command="/go/src/github.com/supabase/auth/auth -c=.env.docker" + postgres: + build: + context: . + dockerfile: Dockerfile.postgres.dev + container_name: auth_postgres + ports: + - '5432:5432' + volumes: + - postgres_data:/var/lib/postgresql/data + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=root + - POSTGRES_DB=postgres + # sets the schema name, this should match the `NAMESPACE` env var set in your .env file + - DB_NAMESPACE=auth +volumes: + postgres_data: diff --git a/docs/admin.go b/docs/admin.go new file mode 100644 index 000000000..5c892227c --- /dev/null +++ b/docs/admin.go @@ -0,0 +1,106 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import ( + "github.com/supabase/auth/internal/api" +) + +// swagger:route GET /admin/users admin admin-list-users +// List all users. +// security: +// - bearer: +// responses: +// 200: adminListUserResponse +// 401: unauthorizedError + +// The list of users. +// swagger:response adminListUserResponse +type adminListUserResponseWrapper struct { + // in:body + Body api.AdminListUsersResponse +} + +// swagger:route POST /admin/users admin admin-create-user +// Returns the created user. +// security: +// - bearer: +// responses: +// 200: userResponse +// 401: unauthorizedError + +// The user to be created. +// swagger:parameters admin-create-user +type adminUserParamsWrapper struct { + // in:body + Body api.AdminUserParams +} + +// swagger:route GET /admin/user/{user_id} admin admin-get-user +// Get a user. +// security: +// - bearer: +// parameters: +// + name: user_id +// in: path +// description: The user's id +// required: true +// responses: +// 200: userResponse +// 401: unauthorizedError + +// The user specified. +// swagger:response userResponse + +// swagger:route PUT /admin/user/{user_id} admin admin-update-user +// Update a user. +// security: +// - bearer: +// parameters: +// + name: user_id +// in: path +// description: The user's id +// required: true +// responses: +// 200: userResponse +// 401: unauthorizedError + +// The updated user. +// swagger:response userResponse + +// swagger:route DELETE /admin/user/{user_id} admin admin-delete-user +// Deletes a user. +// security: +// - bearer: +// parameters: +// + name: user_id +// in: path +// description: The user's id +// required: true +// responses: +// 200: deleteUserResponse +// 401: unauthorizedError + +// The updated user. +// swagger:response deleteUserResponse +type deleteUserResponseWrapper struct{} + +// swagger:route POST /admin/generate_link admin admin-generate-link +// Generates an email action link. +// security: +// - bearer: +// responses: +// 200: generateLinkResponse +// 401: unauthorizedError + +// swagger:parameters admin-generate-link +type generateLinkParams struct { + // in:body + Body api.GenerateLinkParams +} + +// The response object for generate link. +// swagger:response generateLinkResponse +type generateLinkResponseWrapper struct { + // in:body + Body api.GenerateLinkResponse +} diff --git a/docs/doc.go b/docs/doc.go new file mode 100644 index 000000000..5d5a1483c --- /dev/null +++ b/docs/doc.go @@ -0,0 +1,20 @@ +// Package classification gotrue +// +// Documentation of the gotrue API. +// +// Schemes: http, https +// BasePath: / +// Version: 1.0.0 +// Host: localhost:9999 +// +// SecurityDefinitions: +// bearer: +// type: apiKey +// name: Authentication +// in: header +// +// Produces: +// - application/json +// +// swagger:meta +package docs diff --git a/docs/errors.go b/docs/errors.go new file mode 100644 index 000000000..a40644585 --- /dev/null +++ b/docs/errors.go @@ -0,0 +1,6 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +// This endpoint requires a bearer token. +// swagger:response unauthorizedError +type unauthorizedError struct{} diff --git a/docs/health.go b/docs/health.go new file mode 100644 index 000000000..3034fadad --- /dev/null +++ b/docs/health.go @@ -0,0 +1,15 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import "github.com/supabase/auth/internal/api" + +// swagger:route GET /health health health +// The healthcheck endpoint for gotrue. Returns the current gotrue version. +// responses: +// 200: healthCheckResponse + +// swagger:response healthCheckResponse +type healthCheckResponseWrapper struct { + // in:body + Body api.HealthCheckResponse +} diff --git a/docs/invite.go b/docs/invite.go new file mode 100644 index 000000000..2775cfb4f --- /dev/null +++ b/docs/invite.go @@ -0,0 +1,18 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import "github.com/supabase/auth/internal/api" + +// swagger:route POST /invite invite invite +// Sends an invite link to the user. +// responses: +// 200: inviteResponse + +// swagger:parameters invite +type inviteParamsWrapper struct { + // in:body + Body api.InviteParams +} + +// swagger:response inviteResponse +type inviteResponseWrapper struct{} diff --git a/docs/logout.go b/docs/logout.go new file mode 100644 index 000000000..f0b112573 --- /dev/null +++ b/docs/logout.go @@ -0,0 +1,12 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +// swagger:route POST /logout logout logout +// Logs out the user. +// security: +// - bearer: +// responses: +// 204: logoutResponse + +// swagger:response logoutResponse +type logoutResponseWrapper struct{} diff --git a/docs/oauth.go b/docs/oauth.go new file mode 100644 index 000000000..9b3c223bd --- /dev/null +++ b/docs/oauth.go @@ -0,0 +1,25 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +// swagger:route GET /authorize oauth authorize +// Redirects the user to the 3rd-party OAuth provider to start the OAuth1.0 or OAuth2.0 authentication process. +// parameters: +// + name: redirect_to +// in: query +// description: The redirect url to return the user to after the `/callback` endpoint has completed. +// required: false +// responses: +// 302: authorizeResponse + +// Redirects user to the 3rd-party OAuth provider +// swagger:response authorizeResponse +type authorizeResponseWrapper struct{} + +// swagger:route GET /callback oauth callback +// Receives the redirect from an external provider during the OAuth authentication process. Starts the process of creating an access and refresh token. +// responses: +// 302: callbackResponse + +// Redirects user to the redirect url specified in `/authorize`. If no `redirect_url` is provided, the user will be redirected to the `SITE_URL`. +// swagger:response callbackResponse +type callbackResponseWrapper struct{} diff --git a/docs/otp.go b/docs/otp.go new file mode 100644 index 000000000..a62fa07de --- /dev/null +++ b/docs/otp.go @@ -0,0 +1,19 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import "github.com/supabase/auth/internal/api" + +// swagger:route POST /otp otp otp +// Passwordless sign-in method for email or phone. +// responses: +// 200: otpResponse + +// swagger:parameters otp +type otpParamsWrapper struct { + // Only an email or phone should be provided. + // in:body + Body api.OtpParams +} + +// swagger:response otpResponse +type otpResponseWrapper struct{} diff --git a/docs/recover.go b/docs/recover.go new file mode 100644 index 000000000..1bd249aa9 --- /dev/null +++ b/docs/recover.go @@ -0,0 +1,18 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import "github.com/supabase/auth/internal/api" + +// swagger:route POST /recover recovery recovery +// Sends a password recovery email link to the user's email. +// responses: +// 200: recoveryResponse + +// swagger:parameters recovery +type recoveryParamsWrapper struct { + // in:body + Body api.RecoverParams +} + +// swagger:response recoveryResponse +type recoveryResponseWrapper struct{} diff --git a/docs/settings.go b/docs/settings.go new file mode 100644 index 000000000..ff5d4ed31 --- /dev/null +++ b/docs/settings.go @@ -0,0 +1,15 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import "github.com/supabase/auth/internal/api" + +// swagger:route GET /settings settings settings +// Returns the configuration settings for the gotrue server. +// responses: +// 200: settingsResponse + +// swagger:response settingsResponse +type settingsResponseWrapper struct { + // in:body + Body api.Settings +} diff --git a/docs/signup.go b/docs/signup.go new file mode 100644 index 000000000..a69f0155f --- /dev/null +++ b/docs/signup.go @@ -0,0 +1,17 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import ( + "github.com/supabase/auth/internal/api" +) + +// swagger:route POST /signup signup signup +// Password-based signup with either email or phone. +// responses: +// 200: userResponse + +// swagger:parameters signup +type signupParamsWrapper struct { + // in:body + Body api.SignupParams +} diff --git a/docs/token.go b/docs/token.go new file mode 100644 index 000000000..b4ae542ba --- /dev/null +++ b/docs/token.go @@ -0,0 +1,34 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import ( + "github.com/supabase/auth/internal/api" +) + +// swagger:route POST /token?grant_type=password token token-password +// Signs in a user with a password. +// responses: +// 200: tokenResponse + +// swagger:parameters token-password +type tokenPasswordGrantParamsWrapper struct { + // in:body + Body api.PasswordGrantParams +} + +// swagger:route POST /token?grant_type=refresh_token token token-refresh +// Refreshes a user's refresh token. +// responses: +// 200: tokenResponse + +// swagger:parameters token-refresh +type tokenRefreshTokenGrantParamsWrapper struct { + // in:body + Body api.RefreshTokenGrantParams +} + +// swagger:response tokenResponse +type tokenResponseWrapper struct { + // in:body + Body api.AccessTokenResponse +} diff --git a/docs/user.go b/docs/user.go new file mode 100644 index 000000000..464abfc8c --- /dev/null +++ b/docs/user.go @@ -0,0 +1,37 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import ( + "github.com/supabase/auth/internal/api" + "github.com/supabase/auth/internal/models" +) + +// swagger:route GET /user user user-get +// Get information for the logged-in user. +// security: +// - bearer: +// responses: +// 200: userResponse +// 401: unauthorizedError + +// The current user. +// swagger:response userResponse +type userResponseWrapper struct { + // in:body + Body models.User +} + +// swagger:route PUT /user user user-put +// Returns the updated user. +// security: +// - bearer: +// responses: +// 200: userResponse +// 401: unauthorizedError + +// The current user. +// swagger:parameters user-put +type userUpdateParams struct { + // in:body + Body api.UserUpdateParams +} diff --git a/docs/verify.go b/docs/verify.go new file mode 100644 index 000000000..35906505f --- /dev/null +++ b/docs/verify.go @@ -0,0 +1,24 @@ +//lint:file-ignore U1000 ignore go-swagger template +package docs + +import ( + "github.com/supabase/auth/internal/api" +) + +// swagger:route GET /verify verify verify-get +// Verifies a sign up. + +// swagger:parameters verify-get +type verifyGetParamsWrapper struct { + // in:query + api.VerifyParams +} + +// swagger:route POST /verify verify verify-post +// Verifies a sign up. + +// swagger:parameters verify-post +type verifyPostParamsWrapper struct { + // in:body + Body api.VerifyParams +} diff --git a/example.docker.env b/example.docker.env new file mode 100644 index 000000000..477a5d11d --- /dev/null +++ b/example.docker.env @@ -0,0 +1,8 @@ +GOTRUE_SITE_URL="http://localhost:3000" +GOTRUE_JWT_SECRET="" +GOTRUE_DB_MIGRATIONS_PATH=/go/src/github.com/supabase/auth/migrations +GOTRUE_DB_DRIVER=postgres +DATABASE_URL=postgres://supabase_auth_admin:root@postgres:5432/postgres +GOTRUE_API_HOST=0.0.0.0 +API_EXTERNAL_URL="http://localhost:9999" +PORT=9999 diff --git a/example.env b/example.env index 2f87d7ad7..30646ea4f 100644 --- a/example.env +++ b/example.env @@ -9,7 +9,7 @@ GOTRUE_JWT_ADMIN_ROLES="supabase_admin,service_role" # Database & API connection details GOTRUE_DB_DRIVER="postgres" -NAMESPACE="auth" +DB_NAMESPACE="auth" DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/postgres" API_EXTERNAL_URL="http://localhost:9999" GOTRUE_API_HOST="localhost" @@ -49,10 +49,10 @@ GOTRUE_DISABLE_SIGNUP="false" GOTRUE_SITE_URL="http://localhost:3000" GOTRUE_EXTERNAL_EMAIL_ENABLED="true" GOTRUE_EXTERNAL_PHONE_ENABLED="true" -GOTRUE_EXTERNAL_IOS_BUNDLE_ID="com.supabase.gotrue" +GOTRUE_EXTERNAL_IOS_BUNDLE_ID="com.supabase.auth" -# Whitelist redirect to URLs here -GOTRUE_URI_ALLOW_LIST=["http://localhost:3000"] +# Whitelist redirect to URLs here, a comma separated list of URIs (e.g. "https://foo.example.com,https://*.foo.example.com,https://bar.example.com") +GOTRUE_URI_ALLOW_LIST="http://localhost:3000" # Apple OAuth config GOTRUE_EXTERNAL_APPLE_ENABLED="false" @@ -78,6 +78,18 @@ GOTRUE_EXTERNAL_DISCORD_CLIENT_ID="" GOTRUE_EXTERNAL_DISCORD_SECRET="" GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI="https://localhost:9999/callback" +# Facebook OAuth config +GOTRUE_EXTERNAL_FACEBOOK_ENABLED="false" +GOTRUE_EXTERNAL_FACEBOOK_CLIENT_ID="" +GOTRUE_EXTERNAL_FACEBOOK_SECRET="" +GOTRUE_EXTERNAL_FACEBOOK_REDIRECT_URI="https://localhost:9999/callback" + +# Figma OAuth config +GOTRUE_EXTERNAL_FIGMA_ENABLED="false" +GOTRUE_EXTERNAL_FIGMA_CLIENT_ID="" +GOTRUE_EXTERNAL_FIGMA_SECRET="" +GOTRUE_EXTERNAL_FIGMA_REDIRECT_URI="https://localhost:9999/callback" + # Gitlab OAuth config GOTRUE_EXTERNAL_GITLAB_ENABLED="false" GOTRUE_EXTERNAL_GITLAB_CLIENT_ID="" @@ -96,11 +108,11 @@ GOTRUE_EXTERNAL_GITHUB_CLIENT_ID="" GOTRUE_EXTERNAL_GITHUB_SECRET="" GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI="http://localhost:9999/callback" -# Facebook OAuth config -GOTRUE_EXTERNAL_FACEBOOK_ENABLED="false" -GOTRUE_EXTERNAL_FACEBOOK_CLIENT_ID="" -GOTRUE_EXTERNAL_FACEBOOK_SECRET="" -GOTRUE_EXTERNAL_FACEBOOK_REDIRECT_URI="https://localhost:9999/callback" +# Kakao OAuth config +GOTRUE_EXTERNAL_KAKAO_ENABLED="false" +GOTRUE_EXTERNAL_KAKAO_CLIENT_ID="" +GOTRUE_EXTERNAL_KAKAO_SECRET="" +GOTRUE_EXTERNAL_KAKAO_REDIRECT_URI="http://localhost:9999/callback" # Notion OAuth config GOTRUE_EXTERNAL_NOTION_ENABLED="false" @@ -126,6 +138,13 @@ GOTRUE_EXTERNAL_SPOTIFY_CLIENT_ID="" GOTRUE_EXTERNAL_SPOTIFY_SECRET="" GOTRUE_EXTERNAL_SPOTIFY_REDIRECT_URI="http://localhost:9999/callback" +# Keycloak OAuth config +GOTRUE_EXTERNAL_KEYCLOAK_ENABLED="false" +GOTRUE_EXTERNAL_KEYCLOAK_CLIENT_ID="" +GOTRUE_EXTERNAL_KEYCLOAK_SECRET="" +GOTRUE_EXTERNAL_KEYCLOAK_REDIRECT_URI="http://localhost:9999/callback" +GOTRUE_EXTERNAL_KEYCLOAK_URL="https://keycloak.example.com/auth/realms/myrealm" + # Linkedin OAuth config GOTRUE_EXTERNAL_LINKEDIN_ENABLED="true" GOTRUE_EXTERNAL_LINKEDIN_CLIENT_ID="" @@ -135,7 +154,29 @@ GOTRUE_EXTERNAL_LINKEDIN_SECRET="" GOTRUE_EXTERNAL_SLACK_ENABLED="false" GOTRUE_EXTERNAL_SLACK_CLIENT_ID="" GOTRUE_EXTERNAL_SLACK_SECRET="" -GOTRUE_EXTERNAL_SLACK_REDIRECT_URI="https://localhost:9999/callback" +GOTRUE_EXTERNAL_SLACK_REDIRECT_URI="http://localhost:9999/callback" + +# WorkOS OAuth config +GOTRUE_EXTERNAL_WORKOS_ENABLED="true" +GOTRUE_EXTERNAL_WORKOS_CLIENT_ID="" +GOTRUE_EXTERNAL_WORKOS_SECRET="" +GOTRUE_EXTERNAL_WORKOS_REDIRECT_URI="http://localhost:9999/callback" + +# Zoom OAuth config +GOTRUE_EXTERNAL_ZOOM_ENABLED="false" +GOTRUE_EXTERNAL_ZOOM_CLIENT_ID="" +GOTRUE_EXTERNAL_ZOOM_SECRET="" +GOTRUE_EXTERNAL_ZOOM_REDIRECT_URI="http://localhost:9999/callback" + +# Web3 Solana config +GOTRUE_EXTERNAL_WEB3_SOLANA_ENABLED="true" +GOTRUE_EXTERNAL_WEB3_SOLANA_MAXIMUM_VALIDITY_DURATION="10m" + +# Anonymous auth config +GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED="false" + +# PKCE Config +GOTRUE_EXTERNAL_FLOW_STATE_EXPIRY_DURATION="300s" # Phone provider config GOTRUE_SMS_AUTOCONFIRM="false" @@ -159,6 +200,7 @@ GOTRUE_SMS_VONAGE_FROM="" GOTRUE_SECURITY_CAPTCHA_ENABLED="false" GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha" GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000" +GOTRUE_SECURITY_CAPTCHA_TIMEOUT="10s" GOTRUE_SESSION_KEY="" # SAML config @@ -171,18 +213,30 @@ GOTRUE_EXTERNAL_SAML_SIGNING_KEY="" # Additional Security config GOTRUE_LOG_LEVEL="debug" -GOTRUE_REFRESH_TOKEN_ROTATION_ENABLED="false" +GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED="false" +GOTRUE_SECURITY_REFRESH_TOKEN_REUSE_INTERVAL="0" +GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION="false" GOTRUE_OPERATOR_TOKEN="unused-operator-token" GOTRUE_RATE_LIMIT_HEADER="X-Forwarded-For" GOTRUE_RATE_LIMIT_EMAIL_SENT="100" -# Webhook config -GOTRUE_WEBHOOK_URL=http://register-lambda:3000/ -GOTRUE_WEBHOOK_SECRET=test_secret -GOTRUE_WEBHOOK_RETRIES=5 -GOTRUE_WEBHOOK_TIMEOUT_SEC=3 -GOTRUE_WEBHOOK_EVENTS=validate,signup,login +GOTRUE_MAX_VERIFIED_FACTORS=10 + +# Auth Hook Configuration +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_ENABLED=false +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_URI="" +# Only for HTTPS Hooks +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_SECRET="" + +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_ENABLED=false +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_URI="" +# Only for HTTPS Hooks +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_SECRET="" + + +# Test OTP Config +GOTRUE_SMS_TEST_OTP=":, :..." +GOTRUE_SMS_TEST_OTP_VALID_UNTIL="" # (e.g. 2023-09-29T08:14:06Z) -# Cookie config -GOTRUE_COOKIE_KEY: "sb" -GOTRUE_COOKIE_DOMAIN: "localhost" \ No newline at end of file +GOTRUE_MFA_WEB_AUTHN_ENROLL_ENABLED="false" +GOTRUE_MFA_WEB_AUTHN_VERIFY_ENABLED="false" diff --git a/go.mod b/go.mod index 1563512c6..f2ecc55b1 100644 --- a/go.mod +++ b/go.mod @@ -1,57 +1,177 @@ -module github.com/netlify/gotrue +module github.com/supabase/auth require ( - cloud.google.com/go v0.67.0 // indirect - github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20170623214735-571947b0f240 github.com/Masterminds/semver/v3 v3.1.1 // indirect + github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40 + github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62 - github.com/beevik/etree v1.1.0 - github.com/coreos/go-oidc/v3 v3.0.0 + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc + github.com/coreos/go-oidc/v3 v3.6.0 github.com/didip/tollbooth/v5 v5.1.1 - github.com/fatih/color v1.10.0 // indirect - github.com/go-chi/chi v4.0.2+incompatible - github.com/go-sql-driver/mysql v1.5.0 - github.com/gobuffalo/envy v1.9.0 // indirect - github.com/gobuffalo/fizz v1.13.0 // indirect - github.com/gobuffalo/flect v0.2.2 // indirect - github.com/gobuffalo/nulls v0.4.0 // indirect - github.com/gobuffalo/packr/v2 v2.8.1 // indirect - github.com/gobuffalo/plush/v4 v4.1.0 // indirect - github.com/gobuffalo/pop/v5 v5.3.3 - github.com/gobuffalo/validate/v3 v3.3.0 // indirect - github.com/gofrs/uuid v4.0.0+incompatible - github.com/golang-jwt/jwt v3.2.1+incompatible - github.com/gorilla/securecookie v1.1.1 - github.com/gorilla/sessions v1.1.1 - github.com/imdario/mergo v0.0.0-20160216103600-3e95a51e0639 - github.com/jackc/pgproto3/v2 v2.0.7 // indirect - github.com/jmoiron/sqlx v1.3.1 // indirect - github.com/joho/godotenv v1.3.0 + github.com/gobuffalo/validate/v3 v3.3.3 // indirect + github.com/gobwas/glob v0.2.3 + github.com/gofrs/uuid v4.3.1+incompatible + github.com/jackc/pgconn v1.14.3 + github.com/jackc/pgerrcode v0.0.0-20201024163028-a0d42d470451 + github.com/jackc/pgproto3/v2 v2.3.3 // indirect + github.com/jmoiron/sqlx v1.3.5 + github.com/joho/godotenv v1.4.0 github.com/kelseyhightower/envconfig v1.4.0 - github.com/lestrrat-go/jwx v0.9.0 - github.com/lib/pq v1.9.0 // indirect - github.com/microcosm-cc/bluemonday v1.0.16 // indirect + github.com/microcosm-cc/bluemonday v1.0.26 // indirect + github.com/mitchellh/mapstructure v1.5.0 github.com/mrjones/oauth v0.0.0-20190623134757-126b35219450 - github.com/netlify/mailme v1.1.1 - github.com/opentracing/opentracing-go v1.1.0 - github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pkg/errors v0.9.1 - github.com/rs/cors v1.6.0 - github.com/russellhaering/gosaml2 v0.6.1-0.20210916051624-757d23f1bc28 - github.com/russellhaering/goxmldsig v1.1.1 + github.com/pquerna/otp v1.4.0 + github.com/rs/cors v1.11.0 github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35 github.com/sethvargo/go-password v0.2.0 - github.com/sirupsen/logrus v1.7.0 - github.com/spf13/cobra v1.1.3 - github.com/stretchr/testify v1.6.1 - golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad - golang.org/x/net v0.0.0-20220121210141-e204ce36a2ba // indirect - golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 - golang.org/x/sync v0.0.0-20201207232520-09787c993a3a // indirect - golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 // indirect - gopkg.in/DataDog/dd-trace-go.v1 v1.12.1 + github.com/sirupsen/logrus v1.9.3 + github.com/spf13/cobra v1.7.0 + github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.35.0 + golang.org/x/oauth2 v0.17.0 gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df - gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0 // indirect ) -go 1.13 +require ( + github.com/bits-and-blooms/bitset v1.13.0 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect + github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936 // indirect + github.com/fxamacker/cbor/v2 v2.7.0 // indirect + github.com/getkin/kin-openapi v0.128.0 // indirect + github.com/go-jose/go-jose/v3 v3.0.4 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-webauthn/x v0.1.12 // indirect + github.com/gobuffalo/nulls v0.4.2 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/google/go-tpm v0.9.1 // indirect + github.com/invopop/yaml v0.3.1 // indirect + github.com/jackc/pgx/v4 v4.18.2 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/lestrrat-go/blackmagic v1.0.2 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/httprc v1.0.5 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/segmentio/asm v1.2.0 // indirect + github.com/speakeasy-api/openapi-overlay v0.9.0 // indirect + github.com/vmware-labs/yaml-jsonpath v0.3.2 // indirect + github.com/x448/float16 v0.8.4 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect +) + +require ( + github.com/XSAM/otelsql v0.26.0 + github.com/bombsimon/logrusr/v3 v3.0.0 + go.opentelemetry.io/contrib/instrumentation/runtime v0.45.0 + go.opentelemetry.io/otel v1.26.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.19.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 + go.opentelemetry.io/otel/metric v1.26.0 + go.opentelemetry.io/otel/sdk v1.26.0 + go.opentelemetry.io/otel/sdk/metric v1.26.0 + go.opentelemetry.io/otel/trace v1.26.0 + gopkg.in/h2non/gock.v1 v1.1.2 +) + +require ( + github.com/bits-and-blooms/bloom/v3 v3.6.0 + github.com/btcsuite/btcutil v1.0.2 + github.com/crewjam/saml v0.4.14 + github.com/fatih/structs v1.1.0 + github.com/fsnotify/fsnotify v1.7.0 + github.com/go-chi/chi/v5 v5.0.12 + github.com/go-webauthn/webauthn v0.11.1 + github.com/gobuffalo/pop/v6 v6.1.1 + github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/lestrrat-go/jwx/v2 v2.1.0 + github.com/oapi-codegen/oapi-codegen/v2 v2.4.2-0.20250102212541-8bbe226927c9 + github.com/oapi-codegen/runtime v1.1.1 + github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 + github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 + github.com/xeipuuv/gojsonschema v1.2.0 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.26.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.26.0 + go.opentelemetry.io/otel/exporters/prometheus v0.48.0 +) + +require ( + github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect + github.com/aymerick/douceur v0.2.0 // indirect + github.com/beevik/etree v1.1.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/crewjam/httperr v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.16.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/gobuffalo/envy v1.10.2 // indirect + github.com/gobuffalo/fizz v1.14.4 // indirect + github.com/gobuffalo/flect v1.0.2 // indirect + github.com/gobuffalo/github_flavored_markdown v1.1.3 // indirect + github.com/gobuffalo/helpers v0.6.7 // indirect + github.com/gobuffalo/plush/v4 v4.1.18 // indirect + github.com/gobuffalo/tags/v3 v3.1.4 // indirect + github.com/golang-jwt/jwt/v4 v4.5.2 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/css v1.0.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 // indirect + github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgtype v1.14.0 // indirect + github.com/jonboulle/clockwork v0.2.2 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/luna-duclos/instrumentedsql v1.1.3 // indirect + github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_golang v1.19.0 + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.48.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + github.com/russellhaering/goxmldsig v1.3.0 // indirect + github.com/sergi/go-diff v1.2.0 // indirect + github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d // indirect + github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.2 // indirect + go.opentelemetry.io/proto/otlp v1.2.0 // indirect + golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb + golang.org/x/net v0.36.0 // indirect + golang.org/x/sync v0.11.0 + golang.org/x/sys v0.30.0 + golang.org/x/text v0.22.0 // indirect + golang.org/x/time v0.5.0 + google.golang.org/appengine v1.6.8 // indirect + google.golang.org/grpc v1.63.2 // indirect + google.golang.org/protobuf v1.34.2 // indirect + gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +go 1.23.7 diff --git a/go.sum b/go.sum index 6ca10922b..9d75b1423 100644 --- a/go.sum +++ b/go.sum @@ -1,278 +1,194 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= -cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= -cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= -cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= -cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= -cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= -cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= -cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= -cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= -cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= -cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= -cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= -cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= -cloud.google.com/go v0.67.0 h1:YIkzmqUfVGiGPpT98L8sVvUIkDno6UlrDxw4NR6z5ak= -cloud.google.com/go v0.67.0/go.mod h1:YNan/mUhNZFrYUor0vqrsQ0Ffl7Xtm/ACOy/vsTS858= -cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= -cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= -cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= -cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= -cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= -cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= -cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= -cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= -cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= -cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= -cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= -cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= -cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= -cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= -cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= -cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= -cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= -dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/DataDog/datadog-go v2.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= -github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20170623214735-571947b0f240 h1:bCOIpv1VinSRhS5ezZeCEGG82gib2WtXfiJOHmMSuls= -github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20170623214735-571947b0f240/go.mod h1:aJ4qN3TfrelA6NZ6AXsXRfmEVaYin3EDbSPJrKS8OXo= -github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= -github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= -github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878/go.mod h1:3AMJUQhVx52RsWOnlkpikZr01T/yAVN2gn0861vByNg= -github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= +github.com/XSAM/otelsql v0.26.0 h1:UhAGVBD34Ctbh2aYcm/JAdL+6T6ybrP+YMWYkHqCdmo= +github.com/XSAM/otelsql v0.26.0/go.mod h1:5ciw61eMSh+RtTPN8spvPEPLJpAErZw8mFFPNfYiaxA= +github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40 h1:uz4N2yHL4MF8vZX+36n+tcxeUf8D/gL4aJkyouhDw4A= +github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40/go.mod h1:dytw+5qs+pdi61fO/S4OmXR7AuEq/HvNCuG03KxQHT4= +github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= +github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= +github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= +github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= +github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62 h1:vMqcPzLT1/mbYew0gM6EJy4/sCNy9lY9rmlFO+pPwhY= github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62/go.mod h1:r5ZalvRl3tXevRNJkwIB6DC4DD3DMjIlY9NEU1XGoaQ= github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs= github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= -github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= -github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= -github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= -github.com/bugsnag/bugsnag-go v1.5.3/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= -github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.10.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= +github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bloom/v3 v3.6.0 h1:dTU0OVLJSoOhz9m68FTXMFfA39nR8U/nTCs1zb26mOI= +github.com/bits-and-blooms/bloom/v3 v3.6.0/go.mod h1:VKlUSvp0lFIYqxJjzdnSsZEw4iHb1kOL2tfHTgyJBHg= +github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= +github.com/bombsimon/logrusr/v3 v3.0.0 h1:tcAoLfuAhKP9npBxWzSdpsvKPQt1XV02nSf2lZA82TQ= +github.com/bombsimon/logrusr/v3 v3.0.0/go.mod h1:PksPPgSFEL2I52pla2glgCyyd2OqOHAnFF5E+g8Ixco= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= +github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= +github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= +github.com/btcsuite/btcutil v1.0.2 h1:9iZ1Terx9fMIOtq1VrwdqfsATL9MC2l8ZrUY6YZ2uts= +github.com/btcsuite/btcutil v1.0.2/go.mod h1:j9HUFwoQRsZL3V4n+qG+CUnEGHOarIxfC3Le2Yhbcts= +github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg= +github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY= +github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg3lh6TiUghc= +github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY= +github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= -github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-oidc/v3 v3.0.0 h1:/mAA0XMgYJw2Uqm7WKGCsKnjitE/+A0FFbOmiRJm7LQ= -github.com/coreos/go-oidc/v3 v3.0.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o= +github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= -github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/crewjam/httperr v0.2.0 h1:b2BfXR8U3AlIHwNeFFvZ+BV1LFvKLlzMjzaTnZMybNo= +github.com/crewjam/httperr v0.2.0/go.mod h1:Jlz+Sg/XqBQhyMjdDiC+GNNRzZTD7x39Gu3pglZ5oH4= +github.com/crewjam/saml v0.4.14 h1:g9FBNx62osKusnFzs3QTN5L9CVA/Egfgm+stJShzw/c= +github.com/crewjam/saml v0.4.14/go.mod h1:UVSZCf18jJkk6GpWNVqcyQJMD5HsRugBPf4I1nl2mME= +github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/didip/tollbooth/v5 v5.1.1 h1:QpKFg56jsbNuQ6FFj++Z1gn2fbBsvAc1ZPLUaDOYW5k= github.com/didip/tollbooth/v5 v5.1.1/go.mod h1:d9rzwOULswrD3YIrAQmP3bfjxab32Df4IaO6+D25l9g= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= -github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= -github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= -github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= +github.com/dprotaso/go-yit v0.0.0-20191028211022-135eb7262960/go.mod h1:9HQzr9D/0PGwMEbC3d5AB7oi67+h4TsQqItC1GVYG58= +github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936 h1:PRxIJD8XjimM5aTknUK9w6DHLDox2r2M3DI4i2pnd3w= +github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936/go.mod h1:ttYvX5qlB+mlV1okblJqcSMtR4c52UKxDiX9GRBS8+Q= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= -github.com/go-chi/chi v4.0.2+incompatible h1:maB6vn6FqCxrpz4FqWdh4+lwpyZIQS7YEAUcHlgXVRs= -github.com/go-chi/chi v4.0.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= -github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/getkin/kin-openapi v0.128.0 h1:jqq3D9vC9pPq1dGcOCv7yOp1DaEe7c/T1vzcLbITSp4= +github.com/getkin/kin-openapi v0.128.0/go.mod h1:OZrfXzUfGrNbsKj+xmFBx6E5c6yH3At/tAKSc2UszXM= +github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= +github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= +github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gobuffalo/attrs v0.1.0/go.mod h1:fmNpaWyHM0tRm8gCZWKx8yY9fvaNLo2PyzBNSrBZ5Hw= -github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= -github.com/gobuffalo/envy v1.7.1/go.mod h1:FurDp9+EDPE4aIUS3ZLyD+7/9fpx7YRt/ukY6jIHf0w= -github.com/gobuffalo/envy v1.8.1/go.mod h1:FurDp9+EDPE4aIUS3ZLyD+7/9fpx7YRt/ukY6jIHf0w= -github.com/gobuffalo/envy v1.9.0 h1:eZR0DuEgVLfeIb1zIKt3bT4YovIMf9O9LXQeCZLXpqE= -github.com/gobuffalo/envy v1.9.0/go.mod h1:FurDp9+EDPE4aIUS3ZLyD+7/9fpx7YRt/ukY6jIHf0w= -github.com/gobuffalo/fizz v1.10.0/go.mod h1:J2XGPO0AfJ1zKw7+2BA+6FEGAkyEsdCOLvN93WCT2WI= -github.com/gobuffalo/fizz v1.13.0 h1:XzcBh8DLZH2BgEH77p6q+EKbd8FZyyUXgokUmKXk5ow= -github.com/gobuffalo/fizz v1.13.0/go.mod h1:cXLjhE5p3iuIes6AGZ/9+dfyOkehlB2Vldj0Iw2Uu38= -github.com/gobuffalo/flect v0.1.5/go.mod h1:W3K3X9ksuZfir8f/LrfVtWmCDQFfayuylOJ7sz/Fj80= -github.com/gobuffalo/flect v0.2.0/go.mod h1:W3K3X9ksuZfir8f/LrfVtWmCDQFfayuylOJ7sz/Fj80= -github.com/gobuffalo/flect v0.2.1/go.mod h1:vmkQwuZYhN5Pc4ljYQZzP+1sq+NEkK+lh20jmEmX3jc= -github.com/gobuffalo/flect v0.2.2 h1:PAVD7sp0KOdfswjAw9BpLCU9hXo7wFSzgpQ+zNeks/A= -github.com/gobuffalo/flect v0.2.2/go.mod h1:vmkQwuZYhN5Pc4ljYQZzP+1sq+NEkK+lh20jmEmX3jc= -github.com/gobuffalo/genny/v2 v2.0.5/go.mod h1:kRkJuAw9mdI37AiEYjV4Dl+TgkBDYf8HZVjLkqe5eBg= -github.com/gobuffalo/github_flavored_markdown v1.1.0 h1:8Zzj4fTRl/OP2R7sGerzSf6g2nEJnaBEJe7UAOiEvbQ= -github.com/gobuffalo/github_flavored_markdown v1.1.0/go.mod h1:TSpTKWcRTI0+v7W3x8dkSKMLJSUpuVitlptCkpeY8ic= -github.com/gobuffalo/helpers v0.6.0/go.mod h1:pncVrer7x/KRvnL5aJABLAuT/RhKRR9klL6dkUOhyv8= -github.com/gobuffalo/helpers v0.6.1 h1:LLcL4BsiyDQYtMRUUpyFdBFvFXQ6hNYOpwrcYeilVWM= -github.com/gobuffalo/helpers v0.6.1/go.mod h1:wInbDi0vTJKZBviURTLRMFLE4+nF2uRuuL2fnlYo7w4= -github.com/gobuffalo/logger v1.0.1/go.mod h1:2zbswyIUa45I+c+FLXuWl9zSWEiVuthsk8ze5s8JvPs= -github.com/gobuffalo/logger v1.0.3 h1:YaXOTHNPCvkqqA7w05A4v0k2tCdpr+sgFlgINbQ6gqc= -github.com/gobuffalo/logger v1.0.3/go.mod h1:SoeejUwldiS7ZsyCBphOGURmWdwUFXs0J7TCjEhjKxM= -github.com/gobuffalo/nulls v0.2.0/go.mod h1:w4q8RoSCEt87Q0K0sRIZWYeIxkxog5mh3eN3C/n+dUc= -github.com/gobuffalo/nulls v0.4.0 h1:xi+JHGWIetYqLmS520dSWc8Ifj1P0aNXKTVDMVsPXmw= -github.com/gobuffalo/nulls v0.4.0/go.mod h1:2KmsoLnMrxpwPLN5LmBbm6tmttHSIZr/v/OdGsATM3M= -github.com/gobuffalo/packd v0.3.0/go.mod h1:zC7QkmNkYVGKPw4tHpBQ+ml7W/3tIebgeo1b36chA3Q= -github.com/gobuffalo/packd v1.0.0 h1:6ERZvJHfe24rfFmA9OaoKBdC7+c9sydrytMg8SdFGBM= -github.com/gobuffalo/packd v1.0.0/go.mod h1:6VTc4htmJRFB7u1m/4LeMTWjFoYrUiBkU9Fdec9hrhI= -github.com/gobuffalo/packr/v2 v2.7.1/go.mod h1:qYEvAazPaVxy7Y7KR0W8qYEE+RymX74kETFqjFoFlOc= -github.com/gobuffalo/packr/v2 v2.8.0/go.mod h1:PDk2k3vGevNE3SwVyVRgQCCXETC9SaONCNSXT1Q8M1g= -github.com/gobuffalo/packr/v2 v2.8.1 h1:tkQpju6i3EtMXJ9uoF5GT6kB+LMTimDWD8Xvbz6zDVA= -github.com/gobuffalo/packr/v2 v2.8.1/go.mod h1:c/PLlOuTU+p3SybaJATW3H6lX/iK7xEz5OeMf+NnJpg= -github.com/gobuffalo/plush/v4 v4.0.0/go.mod h1:ErFS3UxKqEb8fpFJT7lYErfN/Nw6vHGiDMTjxpk5bQ0= -github.com/gobuffalo/plush/v4 v4.1.0 h1:l39qv4p16mEZIPQX/YyCgkUsYTe/L9otgCoBMlYb4Ms= -github.com/gobuffalo/plush/v4 v4.1.0/go.mod h1:ErFS3UxKqEb8fpFJT7lYErfN/Nw6vHGiDMTjxpk5bQ0= -github.com/gobuffalo/pop/v5 v5.2.0/go.mod h1:Hj586Cr7FoTFNmvzyNdUcajv3r0A+W+bkil4RIX/zKo= -github.com/gobuffalo/pop/v5 v5.3.3 h1:L8TRyREUSO2Jtai3DeqnPTHV0AOAMZYyf6TaVsLBVsc= -github.com/gobuffalo/pop/v5 v5.3.3/go.mod h1:Ey1hqzDLkWQKNEfsnafaz+3P1h/TrS++W9PmpGsNxvk= -github.com/gobuffalo/tags/v3 v3.0.2/go.mod h1:ZQeN6TCTiwAFnS0dNcbDtSgZDwNKSpqajvVtt6mlYpA= -github.com/gobuffalo/tags/v3 v3.1.0 h1:mzdCYooN2VsLRr8KIAdEZ1lh1Py7JSMsiEGCGata2AQ= -github.com/gobuffalo/tags/v3 v3.1.0/go.mod h1:ZQeN6TCTiwAFnS0dNcbDtSgZDwNKSpqajvVtt6mlYpA= -github.com/gobuffalo/validate/v3 v3.0.0/go.mod h1:HFpjq+AIiA2RHoQnQVTFKF/ZpUPXwyw82LgyDPxQ9r0= -github.com/gobuffalo/validate/v3 v3.1.0/go.mod h1:HFpjq+AIiA2RHoQnQVTFKF/ZpUPXwyw82LgyDPxQ9r0= -github.com/gobuffalo/validate/v3 v3.3.0 h1:j++FFx9gtjTmIQeI9xlaIDZ0nV4x8YQZz4RJAlZNUxg= -github.com/gobuffalo/validate/v3 v3.3.0/go.mod h1:HFpjq+AIiA2RHoQnQVTFKF/ZpUPXwyw82LgyDPxQ9r0= -github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/go-webauthn/webauthn v0.11.1 h1:5G/+dg91/VcaJHTtJUfwIlNJkLwbJCcnUc4W8VtkpzA= +github.com/go-webauthn/webauthn v0.11.1/go.mod h1:YXRm1WG0OtUyDFaVAgB5KG7kVqW+6dYCJ7FTQH4SxEE= +github.com/go-webauthn/x v0.1.12 h1:RjQ5cvApzyU/xLCiP+rub0PE4HBZsLggbxGR5ZpUf/A= +github.com/go-webauthn/x v0.1.12/go.mod h1:XlRcGkNH8PT45TfeJYc6gqpOtiOendHhVmnOxh+5yHs= +github.com/gobuffalo/attrs v1.0.3/go.mod h1:KvDJCE0avbufqS0Bw3UV7RQynESY0jjod+572ctX4t8= +github.com/gobuffalo/envy v1.10.2 h1:EIi03p9c3yeuRCFPOKcSfajzkLb3hrRjEpHGI8I2Wo4= +github.com/gobuffalo/envy v1.10.2/go.mod h1:qGAGwdvDsaEtPhfBzb3o0SfDea8ByGn9j8bKmVft9z8= +github.com/gobuffalo/fizz v1.14.4 h1:8uume7joF6niTNWN582IQ2jhGTUoa9g1fiV/tIoGdBs= +github.com/gobuffalo/fizz v1.14.4/go.mod h1:9/2fGNXNeIFOXEEgTPJwiK63e44RjG+Nc4hfMm1ArGM= +github.com/gobuffalo/flect v0.3.0/go.mod h1:5pf3aGnsvqvCj50AVni7mJJF8ICxGZ8HomberC3pXLE= +github.com/gobuffalo/flect v1.0.0/go.mod h1:l9V6xSb4BlXwsxEMj3FVEub2nkdQjWhPvD8XTTlHPQc= +github.com/gobuffalo/flect v1.0.2 h1:eqjPGSo2WmjgY2XlpGwo2NXgL3RucAKo4k4qQMNA5sA= +github.com/gobuffalo/flect v1.0.2/go.mod h1:A5msMlrHtLqh9umBSnvabjsMrCcCpAyzglnDvkbYKHs= +github.com/gobuffalo/genny/v2 v2.1.0/go.mod h1:4yoTNk4bYuP3BMM6uQKYPvtP6WsXFGm2w2EFYZdRls8= +github.com/gobuffalo/github_flavored_markdown v1.1.3 h1:rSMPtx9ePkFB22vJ+dH+m/EUBS8doQ3S8LeEXcdwZHk= +github.com/gobuffalo/github_flavored_markdown v1.1.3/go.mod h1:IzgO5xS6hqkDmUh91BW/+Qxo/qYnvfzoz3A7uLkg77I= +github.com/gobuffalo/helpers v0.6.7 h1:C9CedoRSfgWg2ZoIkVXgjI5kgmSpL34Z3qdnzpfNVd8= +github.com/gobuffalo/helpers v0.6.7/go.mod h1:j0u1iC1VqlCaJEEVkZN8Ia3TEzfj/zoXANqyJExTMTA= +github.com/gobuffalo/logger v1.0.7/go.mod h1:u40u6Bq3VVvaMcy5sRBclD8SXhBYPS0Qk95ubt+1xJM= +github.com/gobuffalo/nulls v0.4.2 h1:GAqBR29R3oPY+WCC7JL9KKk9erchaNuV6unsOSZGQkw= +github.com/gobuffalo/nulls v0.4.2/go.mod h1:EElw2zmBYafU2R9W4Ii1ByIj177wA/pc0JdjtD0EsH8= +github.com/gobuffalo/packd v1.0.2/go.mod h1:sUc61tDqGMXON80zpKGp92lDb86Km28jfvX7IAyxFT8= +github.com/gobuffalo/plush/v4 v4.1.16/go.mod h1:6t7swVsarJ8qSLw1qyAH/KbrcSTwdun2ASEQkOznakg= +github.com/gobuffalo/plush/v4 v4.1.18 h1:bnPjdMTEUQHqj9TNX2Ck3mxEXYZa+0nrFMNM07kpX9g= +github.com/gobuffalo/plush/v4 v4.1.18/go.mod h1:xi2tJIhFI4UdzIL8sxZtzGYOd2xbBpcFbLZlIPGGZhU= +github.com/gobuffalo/pop/v6 v6.1.1 h1:eUDBaZcb0gYrmFnKwpuTEUA7t5ZHqNfvS4POqJYXDZY= +github.com/gobuffalo/pop/v6 v6.1.1/go.mod h1:1n7jAmI1i7fxuXPZjZb0VBPQDbksRtCoFnrDV5IsvaI= +github.com/gobuffalo/tags/v3 v3.1.4 h1:X/ydLLPhgXV4h04Hp2xlbI2oc5MDaa7eub6zw8oHjsM= +github.com/gobuffalo/tags/v3 v3.1.4/go.mod h1:ArRNo3ErlHO8BtdA0REaZxijuWnWzF6PUXngmMXd2I0= +github.com/gobuffalo/validate/v3 v3.3.3 h1:o7wkIGSvZBYBd6ChQoLxkz2y1pfmhbI4jNJYh6PuNJ4= +github.com/gobuffalo/validate/v3 v3.3.3/go.mod h1:YC7FsbJ/9hW/VjQdmXPvFqvRis4vrRYFxr69WiNZw6g= +github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= +github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= -github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= -github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gofrs/uuid v4.3.1+incompatible h1:0/KbAdpx3UXAx1kEOWHJeOkpbgRFGHVgv+CFIY7dBJI= +github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200905233945-acf8798be1f7/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-tpm v0.9.1 h1:0pGc4X//bAlmZzMKf8iz6IsDo1nYTbYJ6FZN/rg4zdM= +github.com/google/go-tpm v0.9.1/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= -github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.1.1 h1:YMDmfaK68mUixINzY/XjscuJ47uXFWSSHzFbBQM0PrE= -github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= -github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= -github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= -github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= -github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= -github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= -github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= -github.com/hashicorp/go-hclog v0.9.1/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= -github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= -github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= -github.com/hashicorp/go-msgpack v0.5.5/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= -github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= -github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= -github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= -github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= -github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= -github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= -github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= -github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= -github.com/hashicorp/raft v1.1.0/go.mod h1:4Ak7FSPnuvmb0GV6vgIAJ4vYT4bek9bb6Q+7HVbyzqM= -github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= -github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/imdario/mergo v0.0.0-20160216103600-3e95a51e0639 h1:VMd01CgpBpmLpuERyY4Oibn2PpcVS1fK9sjh5UZG8+o= -github.com/imdario/mergo v0.0.0-20160216103600-3e95a51e0639/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= -github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 h1:/c3QmbOGMGTOumP2iT/rCwB7b0QDGLKzqOmktBjT+Is= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1/go.mod h1:5SN9VR2LTsRFsrEC6FHgRbTWrTHu6tqPeKxEQv15giM= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/invopop/yaml v0.3.1 h1:f0+ZpmhfBSS4MhG+4HYseMdJhoeeopbSKbq5Rpeelso= +github.com/invopop/yaml v0.3.1/go.mod h1:PMOp3nn4/12yEZUFfmOuNHJsZToEEOwoWsT+D81KkeA= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -280,616 +196,454 @@ github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgO github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= -github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= -github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= -github.com/jackc/pgconn v1.6.0/go.mod h1:yeseQo4xhQbgyJs2c87RAXOH2i624N0Fh1KSPJya7qo= -github.com/jackc/pgconn v1.8.0 h1:FmjZ0rOyXTr1wfWs45i4a9vjnjWUAGpMuQLD9OSs+lw= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.13.0/go.mod h1:AnowpAqO4CMIIJNZl2VJp+KrkAZciAkhEl0W0JIobpI= +github.com/jackc/pgconn v1.14.3 h1:bVoTr12EGANZz66nZPkMInAV/KHD2TxH9npjXXgiB3w= +github.com/jackc/pgconn v1.14.3/go.mod h1:RZbme4uasqzybK2RK5c65VsHxoyaml09lx3tXOcO/VM= +github.com/jackc/pgerrcode v0.0.0-20201024163028-a0d42d470451 h1:WAvSpGf7MsFuzAtK4Vk7R4EVe+liW4x83r4oWu0WHKw= +github.com/jackc/pgerrcode v0.0.0-20201024163028-a0d42d470451/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.7 h1:6Pwi1b3QdY65cuv6SyVO0FgPd5J3Bl7wf/nQQjinHMA= -github.com/jackc/pgproto3/v2 v2.0.7/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUOag= +github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= -github.com/jackc/pgtype v1.3.0/go.mod h1:b0JqxHvPmljG+HQ5IsvQ0yqeSi4nGcDTVjFoiLDb0Ik= -github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= -github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= -github.com/jackc/pgtype v1.6.2 h1:b3pDeuhbbzBYcg5kwNmNDun4pFUD/0AAr1kLXZLeNt8= -github.com/jackc/pgtype v1.6.2/go.mod h1:JCULISAZBFGrHaOXIIFiyfzW5VY0GRitRr8NeJsrdig= -github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o= -github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.12.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= +github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= -github.com/jackc/pgx/v4 v4.6.0/go.mod h1:vPh43ZzxijXUVJ+t/EmXBtFmbFVO72cuneCT9oAlxAg= -github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= -github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= -github.com/jackc/pgx/v4 v4.10.1 h1:/6Q3ye4myIj6AaplUm+eRcz4OhK9HAvFf4ePsG40LJY= -github.com/jackc/pgx/v4 v4.10.1/go.mod h1:QlrWebbs3kqEZPHCTGyxecvzG6tvIsYu+A5b1raylkA= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.17.2/go.mod h1:lcxIZN44yMIrWI78a5CpucdD14hX0SBDbNRvjDBItsw= +github.com/jackc/pgx/v4 v4.18.2 h1:xVpYkNR5pk5bMCZGfClbO962UIqVABcAGt7ha1s/FeU= +github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= -github.com/jmoiron/sqlx v1.3.1 h1:aLN7YINNZ7cYOPK3QC83dbM6KT0NMqVMw961TqrejlE= -github.com/jmoiron/sqlx v1.3.1/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= -github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= -github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= -github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= +github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= +github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= -github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= -github.com/karrick/godirwalk v1.15.3/go.mod h1:j4mkqPuvaLI8mp1DroR3P6ad7cyYd4c1qeJ3RV7ULlk= -github.com/karrick/godirwalk v1.15.8/go.mod h1:j4mkqPuvaLI8mp1DroR3P6ad7cyYd4c1qeJ3RV7ULlk= -github.com/karrick/godirwalk v1.16.1 h1:DynhcF+bztK8gooS0+NDJFrdNZjJ3gzVzC545UNA9iw= -github.com/karrick/godirwalk v1.16.1/go.mod h1:j4mkqPuvaLI8mp1DroR3P6ad7cyYd4c1qeJ3RV7ULlk= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= +github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= -github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lestrrat-go/jwx v0.9.0 h1:Fnd0EWzTm0kFrBPzE/PEPp9nzllES5buMkksPMjEKpM= -github.com/lestrrat-go/jwx v0.9.0/go.mod h1:iEoxlYfZjvoGpuWwxUz+eR5e6KTJGsaRcy/YNA/UnBk= +github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= +github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc v1.0.5 h1:bsTfiH8xaKOJPrg1R+E3iE/AWZr/x0Phj9PBTG/OLUk= +github.com/lestrrat-go/httprc v1.0.5/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx/v2 v2.1.0 h1:0zs7Ya6+39qoit7gwAf+cYm1zzgS3fceIdo7RmQ5lkw= +github.com/lestrrat-go/jwx/v2 v2.1.0/go.mod h1:Xpw9QIaUGiIUD1Wx0NcY1sIHwFf8lDuZn/cmxtXYRys= +github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= +github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8= -github.com/lib/pq v1.9.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/luna-duclos/instrumentedsql v1.1.3 h1:t7mvC0z1jUt5A0UQ6I/0H31ryymuQRnJcWCiqV3lSAA= github.com/luna-duclos/instrumentedsql v1.1.3/go.mod h1:9J1njvFds+zN7y85EDhN9XNQLANWwZt2ULeIC8yMNYs= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/markbates/errx v1.1.0 h1:QDFeR+UP95dO12JgW+tgi2UVfo0V8YBHiUIOaeBPiEI= -github.com/markbates/errx v1.1.0/go.mod h1:PLa46Oex9KNbVDZhKel8v1OT7hD5JZ2eI7AHhA0wswc= -github.com/markbates/oncer v1.0.0 h1:E83IaVAHygyndzPimgUYJjbshhDTALZyXxvk9FOlQRY= -github.com/markbates/oncer v1.0.0/go.mod h1:Z59JA581E9GP6w96jai+TGqafHPW+cPfRxz2aSZ0mcI= -github.com/markbates/safe v1.0.1 h1:yjZkbvRM6IzKj9tlu/zMJLS0n/V351OZWRnF3QfaUxI= -github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= -github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= -github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= -github.com/microcosm-cc/bluemonday v1.0.16 h1:kHmAq2t7WPWLjiGvzKa5o3HzSfahUKiOq7fAPUiMNIc= -github.com/microcosm-cc/bluemonday v1.0.16/go.mod h1:Z0r70sCuXHig8YpBzCc5eGHAap2K7e/u082ZUpDRRqM= -github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= -github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= -github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= -github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/microcosm-cc/bluemonday v1.0.20/go.mod h1:yfBmMi8mxvaZut3Yytv+jTXRY8mxyjJ0/kQBTElld50= +github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58= +github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/mrjones/oauth v0.0.0-20190623134757-126b35219450 h1:j2kD3MT1z4PXCiUllUJF9mWUESr9TWKS7iEKsQ/IipM= github.com/mrjones/oauth v0.0.0-20190623134757-126b35219450/go.mod h1:skjdDftzkFALcuGzYSklqYd8gvat6F1gZJ4YPVbkZpM= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/nats-io/jwt v0.2.6/go.mod h1:mQxQ0uHQ9FhEVPIcTSKwx2lqZEpXWWcCgA7R6NrWvvY= -github.com/nats-io/nats-server/v2 v2.0.0/go.mod h1:RyVdsHHvY4B6c9pWG+uRLpZ0h0XsqiuKp2XCTurP5LI= -github.com/nats-io/nats-streaming-server v0.15.1/go.mod h1:bJ1+2CS8MqvkGfr/NwnCF+Lw6aLnL3F5kenM8bZmdCw= -github.com/nats-io/nats.go v1.8.1/go.mod h1:BrFz9vVn0fU3AcH9Vn4Kd7W0NpJ651tD5omQ3M8LwxM= -github.com/nats-io/nkeys v0.0.2/go.mod h1:dab7URMsZm6Z/jp9Z5UGa87Uutgc2mVpXLC4B7TDb/4= -github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/nats-io/stan.go v0.4.5/go.mod h1:Ji7mK6gRZJSH1nc3ZJH6vi7zn/QnZhpR9Arm4iuzsUQ= -github.com/nats-io/stan.go v0.5.0/go.mod h1:dYqB+vMN3C2F9pT1FRQpg9eHbjPj6mP0yYuyBNuXHZE= -github.com/netlify/mailme v1.1.1 h1:S/ANl+Hy/EIoJUgGiLJYYLZJ2QOTG452R73qTQudMns= -github.com/netlify/mailme v1.1.1/go.mod h1:8g03BJmU+ps7ma5vcH+t8aMtaicQTMX3ffP7RJ8xY8g= -github.com/netlify/netlify-commons v0.32.0 h1:IgpqedBa6aFrc+daRgGZ+SmU9eBXlDXzKSAjevWmshM= -github.com/netlify/netlify-commons v0.32.0/go.mod h1:xZH7auZrc/N/ZKS9BRO74yNf8i9LitXq1h6JVFZ2jTc= -github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= -github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= -github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= +github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/oapi-codegen/oapi-codegen/v2 v2.4.2-0.20250102212541-8bbe226927c9 h1:KXRttm+U6P6gZ5wiOPuAblyxGLEXlT+qjC3vPhe8cg4= +github.com/oapi-codegen/oapi-codegen/v2 v2.4.2-0.20250102212541-8bbe226927c9/go.mod h1:Lzhz8QiRu5FjGuXPT03q6nbgaTZAqidN17pyOKjuXeE= +github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= +github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.10.2/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= +github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= +github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= github.com/patrickmn/go-cache v0.0.0-20170418232947-7ac151875ffb/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml v1.3.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo= -github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ= -github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= +github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= +github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM= -github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/common v0.0.0-20181126121408-4724e9255275/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= -github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/pquerna/otp v1.4.0 h1:wZvl1TIVxKRThZIBiwOOHOGP/1+nZyWBil9Y2XNEDzg= +github.com/pquerna/otp v1.4.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= +github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= +github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.3.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.4.0/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.5.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/rs/cors v1.6.0 h1:G9tHG9lebljV9mfp9SNPDL36nCDxmo3zTlAf1YgvzmI= -github.com/rs/cors v1.6.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= +github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/russellhaering/gosaml2 v0.6.1-0.20210916051624-757d23f1bc28 h1:659ZmS9riGgajUnT9ym74yQSug2KZyvVHi3EmIqASnQ= -github.com/russellhaering/gosaml2 v0.6.1-0.20210916051624-757d23f1bc28/go.mod h1:PiLt5KX4EMjlMIq3WLRR/xb5yqhiwtQhGr8wmU0b08M= -github.com/russellhaering/goxmldsig v1.1.1 h1:vI0r2osGF1A9PLvsGdPUAGwEIrKa4Pj5sesSBsebIxM= -github.com/russellhaering/goxmldsig v1.1.1/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM= +github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35 h1:eajwn6K3weW5cd1ZXLu2sJ4pvwlBiCWY4uDejOr73gM= github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35/go.mod h1:wozgYq9WEBQBaIJe4YZ0qTSFAMxmcwBhQH0fO0R34Z0= -github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= +github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= +github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= +github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sethvargo/go-password v0.2.0 h1:BTDl4CC/gjf/axHMaDQtw507ogrXLci6XRiLc7i/UHI= github.com/sethvargo/go-password v0.2.0/go.mod h1:Ym4Mr9JXLBycr02MFuVQ/0JHidNetSgbzutTr3zsYXE= -github.com/shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:DmcHeT/UuSDXaCVb8IijmL+fHX+FK9TLy98W7mfDXXg= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc h1:jUIKcSPO9MoMJBbEoyE/RJoE8vz7Mb8AjvifMMwSyvY= -github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= -github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d h1:yKm7XZV6j9Ev6lojP2XaIshpT4ymkqhMeSghO5Ps00E= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e h1:qpG93cPwA5f7s/ZPBJnGOYQNK/vKsaDaseuKT5Asee8= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= -github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.4-0.20190321000552-67fc4837d267/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/cobra v0.0.6/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE= -github.com/spf13/cobra v1.1.3 h1:xghbfqPkxzxP3C/f3n5DdpAbdKLj4ZE4BWQI362l53M= -github.com/spf13/cobra v1.1.3/go.mod h1:pGADOWyqRD/YMrPZigI/zbliZ2wVD/23d+is3pSWzOo= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= -github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/speakeasy-api/openapi-overlay v0.9.0 h1:Wrz6NO02cNlLzx1fB093lBlYxSI54VRhy1aSutx0PQg= +github.com/speakeasy-api/openapi-overlay v0.9.0/go.mod h1:f5FloQrHA7MsxYg9djzMD5h6dxrHjVVByWKh7an8TRc= +github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= -github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= -github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= +github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= -github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU= -github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= -github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= -github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 h1:VDuRtwen5Z7QQ5ctuHUse4wAv/JozkKZkdic5vUV4Lg= +github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869/go.mod h1:eHX5nlSMSnyPjUrbYzeqrA8snCe2SKyfizKjU3dkfOw= +github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= +github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/vmware-labs/yaml-jsonpath v0.3.2 h1:/5QKeCBGdsInyDCyVNLbXyilb61MXGi9NP674f9Hobk= +github.com/vmware-labs/yaml-jsonpath v0.3.2/go.mod h1:U6whw1z03QyqgWdgXxvVnQ90zN1BWz5V+51Ewf8k+rQ= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= -go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto= -go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= +go.opentelemetry.io/contrib/instrumentation/runtime v0.45.0 h1:2JydY5UiDpqvj2p7sO9bgHuhTy4hgTZ0ymehdq/Ob0Q= +go.opentelemetry.io/contrib/instrumentation/runtime v0.45.0/go.mod h1:ch3a5QxOqVWxas4CzjCFFOOQe+7HgAXC/N1oVxS9DK4= +go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= +go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.26.0 h1:+hm+I+KigBy3M24/h1p/NHkUx/evbLH0PNcjpMyCHc4= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.26.0/go.mod h1:NjC8142mLvvNT6biDpaMjyz78kyEHIwAJlSX0N9P5KI= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.26.0 h1:HGZWGmCVRCVyAs2GQaiHQPbDHo+ObFWeUEOd+zDnp64= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.26.0/go.mod h1:SaH+v38LSCHddyk7RGlU9uZyQoRrKao6IBnJw6Kbn+c= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.19.0 h1:3d+S281UTjM+AbF31XSOYn1qXn3BgIdWl8HNEpx08Jk= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.19.0/go.mod h1:0+KuTDyKL4gjKCF75pHOX4wuzYDUZYfAQdSu43o+Z2I= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= +go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= +go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= +go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= +go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= +go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= +go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= +go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= +go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= +go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= +go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= +go.opentelemetry.io/proto/otlp v1.2.0 h1:pVeZGk7nXDC9O2hncA6nHldxEjm6LByfA2aN8IOkz94= +go.opentelemetry.io/proto/otlp v1.2.0/go.mod h1:gGpR8txAl5M03pDhMC79G6SdqNV26naRm/KDsgaHD8A= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191122220453-ac88ee75c92c/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY= -golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= -golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= -golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= -golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= +golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= +golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb h1:mIKbk8weKhSeLH2GmUTrvx8CjkyJmnU1wFmg59CUjFA= +golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= -golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20161007143504-f4b625ec9b21/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200219183655-46282727080f/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200927032502-5d4f70055728/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220121210141-e204ce36a2ba h1:6u6sik+bn/y7vILcYkK3iwTBWN7WtBvB0+SZswQnbf8= -golang.org/x/net v0.0.0-20220121210141-e204ce36a2ba/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 h1:ld7aEMNHoBnnDAX15v1T6z31v8HwR2A9FYOuAhWqkwc= -golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= +golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= +golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ= +golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190515120540-06a5c4944438/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/time v0.0.0-20160926182426-711ca1cb8763/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI= -golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191004055002-72853e10c5a3/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200117220505-0cba7a3a9ee9/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200308013534-11ec41452d41/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= -golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE= -golang.org/x/tools v0.0.0-20200929161345-d7fc70abf50f/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= -google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= -google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= -google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= -google.golang.org/api v0.32.0 h1:Le77IccnTqEa8ryp9wIpX5W3zYm7Gf9LhOp9PHcwFts= -google.golang.org/api v0.32.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.6 h1:lMO5rYAqUxkmaj76jAkRUvt5JZgFymx/+Q5Mzfivuhc= -google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= -google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= -google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= -google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200929141702-51c3e5b607fe h1:6SgESkjJknFUnsfQ2yxQbmTAi37BxhwS/riq+VdLo9c= -google.golang.org/genproto v0.0.0-20200929141702-51c3e5b607fe/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= -google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= -google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0= -google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= +google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de h1:F6qOa9AZTYJXOUEr4jDysRDLrm4PHePlge4v4TGAlxY= +google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:VUhTRKeHn9wwcdrk73nvdC9gF178Tzhmt/qyaFcPLSo= +google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de h1:jFNzHPIeuzhdRwVhbZdiym9q0ory/xY3sA+v2wPg8I0= +google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:5iCWqnniDlqZHrd3neWVTOwvh/v6s3232omMecelax8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda h1:LI5DOvAxUPMv/50agcLLoo+AdWc1irS9Rzz4vPuD1V4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= +google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= +google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= -google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -gopkg.in/DataDog/dd-trace-go.v1 v1.12.1 h1:zkyLw+Uq6BvGwy5hFeLVI1ePgOkqJswFPL1uOx6SSA4= -gopkg.in/DataDog/dd-trace-go.v1 v1.12.1/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -898,34 +652,26 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/gomail.v2 v2.0.0-20150902115704-41f357289737/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= +gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= +gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaDva0= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= -gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= -gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= -gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0 h1:POO/ycCATvegFmVuPpQzZFJ+pGZeX22Ufu6fibxDVjU= -gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= -gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20190924164351-c8b7dadae555/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20191026110619-0b21df46bc1d/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= +gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= +honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= diff --git a/hack/coverage.sh b/hack/coverage.sh new file mode 100755 index 000000000..1eb284dac --- /dev/null +++ b/hack/coverage.sh @@ -0,0 +1,21 @@ +FAIL=false + +for PKG in "crypto" "reloader" "utilities/siws" +do + UNCOVERED_FUNCS=$(go tool cover -func=coverage.out | grep "^github.com/supabase/auth/internal/$PKG/" | grep -v '100.0%$') + UNCOVERED_FUNCS_COUNT=$(echo "$UNCOVERED_FUNCS" | wc -l) + + if [ "$UNCOVERED_FUNCS_COUNT" -gt 1 ] # wc -l counts +1 line + then + echo "Package $PKG not covered 100% with tests. $UNCOVERED_FUNCS_COUNT functions need more tests. This is mandatory." + echo "$UNCOVERED_FUNCS" + FAIL=true + fi +done + +if [ "$FAIL" = "true" ] +then + exit 1 +else + exit 0 +fi diff --git a/hack/init_postgres.sql b/hack/init_postgres.sql index cc250f500..d1ef70994 100644 --- a/hack/init_postgres.sql +++ b/hack/init_postgres.sql @@ -1,115 +1,7 @@ CREATE USER supabase_admin LOGIN CREATEROLE CREATEDB REPLICATION BYPASSRLS; -CREATE SCHEMA IF NOT EXISTS auth AUTHORIZATION supabase_admin; - --- auth.users definition - -CREATE TABLE auth.users ( - instance_id uuid NULL, - id uuid NOT NULL UNIQUE, - aud varchar(255) NULL, - "role" varchar(255) NULL, - email varchar(255) NULL UNIQUE, - encrypted_password varchar(255) NULL, - confirmed_at timestamptz NULL, - invited_at timestamptz NULL, - confirmation_token varchar(255) NULL, - confirmation_sent_at timestamptz NULL, - recovery_token varchar(255) NULL, - recovery_sent_at timestamptz NULL, - email_change_token varchar(255) NULL, - email_change varchar(255) NULL, - email_change_sent_at timestamptz NULL, - last_sign_in_at timestamptz NULL, - raw_app_meta_data jsonb NULL, - raw_user_meta_data jsonb NULL, - is_super_admin bool NULL, - created_at timestamptz NULL, - updated_at timestamptz NULL, - CONSTRAINT users_pkey PRIMARY KEY (id) -); -CREATE INDEX users_instance_id_email_idx ON auth.users USING btree (instance_id, email); -CREATE INDEX users_instance_id_idx ON auth.users USING btree (instance_id); -comment on table auth.users is 'Auth: Stores user login data within a secure schema.'; - --- auth.refresh_tokens definition - -CREATE TABLE auth.refresh_tokens ( - instance_id uuid NULL, - id bigserial NOT NULL, - "token" varchar(255) NULL, - user_id varchar(255) NULL, - revoked bool NULL, - created_at timestamptz NULL, - updated_at timestamptz NULL, - CONSTRAINT refresh_tokens_pkey PRIMARY KEY (id) -); -CREATE INDEX refresh_tokens_instance_id_idx ON auth.refresh_tokens USING btree (instance_id); -CREATE INDEX refresh_tokens_instance_id_user_id_idx ON auth.refresh_tokens USING btree (instance_id, user_id); -CREATE INDEX refresh_tokens_token_idx ON auth.refresh_tokens USING btree (token); -comment on table auth.refresh_tokens is 'Auth: Store of tokens used to refresh JWT tokens once they expire.'; - --- auth.instances definition - -CREATE TABLE auth.instances ( - id uuid NOT NULL, - uuid uuid NULL, - raw_base_config text NULL, - created_at timestamptz NULL, - updated_at timestamptz NULL, - CONSTRAINT instances_pkey PRIMARY KEY (id) -); -comment on table auth.instances is 'Auth: Manages users across multiple sites.'; - --- auth.audit_log_entries definition - -CREATE TABLE auth.audit_log_entries ( - instance_id uuid NULL, - id uuid NOT NULL, - payload json NULL, - created_at timestamptz NULL, - CONSTRAINT audit_log_entries_pkey PRIMARY KEY (id) -); -CREATE INDEX audit_logs_instance_id_idx ON auth.audit_log_entries USING btree (instance_id); -comment on table auth.audit_log_entries is 'Auth: Audit trail for user actions.'; - --- auth.schema_migrations definition - -CREATE TABLE auth.schema_migrations ( - "version" varchar(255) NOT NULL, - CONSTRAINT schema_migrations_pkey PRIMARY KEY ("version") -); -comment on table auth.schema_migrations is 'Auth: Manages updates to the auth system.'; - -INSERT INTO auth.schema_migrations (version) -VALUES ('20171026211738'), - ('20171026211808'), - ('20171026211834'), - ('20180103212743'), - ('20180108183307'), - ('20180119214651'), - ('20180125194653'); - --- Gets the User ID from the request cookie -create or replace function auth.uid() returns uuid as $$ - select nullif(current_setting('request.jwt.claim.sub', true), '')::uuid; -$$ language sql stable; - --- Gets the User ID from the request cookie -create or replace function auth.role() returns text as $$ - select nullif(current_setting('request.jwt.claim.role', true), '')::text; -$$ language sql stable; -- Supabase super admin -CREATE USER supabase_auth_admin NOINHERIT CREATEROLE LOGIN NOREPLICATION; -GRANT ALL PRIVILEGES ON SCHEMA auth TO supabase_auth_admin; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA auth TO supabase_auth_admin; -GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA auth TO supabase_auth_admin; -ALTER USER supabase_auth_admin SET search_path = "auth"; -ALTER USER supabase_auth_admin with password 'root'; -ALTER table "auth".users OWNER TO supabase_auth_admin; -ALTER table "auth".refresh_tokens OWNER TO supabase_auth_admin; -ALTER table "auth".audit_log_entries OWNER TO supabase_auth_admin; -ALTER table "auth".instances OWNER TO supabase_auth_admin; -ALTER table "auth".schema_migrations OWNER TO supabase_auth_admin; -ALTER FUNCTION auth.uid OWNER TO supabase_auth_admin; -ALTER FUNCTION auth.role OWNER TO supabase_auth_admin; +CREATE USER supabase_auth_admin NOINHERIT CREATEROLE LOGIN NOREPLICATION PASSWORD 'root'; +CREATE SCHEMA IF NOT EXISTS auth AUTHORIZATION supabase_auth_admin; +GRANT CREATE ON DATABASE postgres TO supabase_auth_admin; +ALTER USER supabase_auth_admin SET search_path = 'auth'; diff --git a/hack/migrate.sh b/hack/migrate.sh index 06c55f972..2d1f0e5e8 100755 --- a/hack/migrate.sh +++ b/hack/migrate.sh @@ -9,7 +9,4 @@ export GOTRUE_DB_DRIVER="postgres" export GOTRUE_DB_DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/$DB_ENV" export GOTRUE_DB_MIGRATIONS_PATH=$DIR/../migrations -echo soda -v -soda drop -d -e $DB_ENV -c $DATABASE -soda create -d -e $DB_ENV -c $DATABASE go run main.go migrate -c $DIR/test.env diff --git a/hack/migrate_postgres.sh b/hack/migrate_postgres.sh deleted file mode 100755 index 59a8e80c2..000000000 --- a/hack/migrate_postgres.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash - -DB_ENV=$1 - -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -DATABASE="$DIR/database.yml" - -export GOTRUE_DB_DRIVER="postgres" -export GOTRUE_DB_DATABASE_URL="postgres://postgres:root@localhost:5432/$DB_ENV?sslmode=disable" -export GOTRUE_DB_MIGRATIONS_PATH=$DIR/../migrations - -go run main.go migrate -c $DIR/test.env diff --git a/hack/mysqld.sh b/hack/mysqld.sh deleted file mode 100755 index 5e47c1b75..000000000 --- a/hack/mysqld.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env bash - -docker rm -f gotrue_mysql >/dev/null 2>/dev/null || true - -docker volume inspect mysql_data 2>/dev/null >/dev/null || docker volume create --name mysql_data >/dev/null - -docker run --name gotrue_mysql \ - -p 3306:3306 \ - -e MYSQL_ALLOW_EMPTY_PASSWORD=yes \ - --volume mysql_data:/var/lib/mysql \ - -d mysql:5.7 mysqld --bind-address=0.0.0.0 diff --git a/hack/postgresd.sh b/hack/postgresd.sh index 3d6ee6581..c4b6a58e3 100755 --- a/hack/postgresd.sh +++ b/hack/postgresd.sh @@ -11,4 +11,4 @@ docker run --name gotrue_postgresql \ -e POSTGRES_DB=postgres \ --volume postgres_data:/var/lib/postgresql/data \ --volume "$(pwd)"/hack/init_postgres.sql:/docker-entrypoint-initdb.d/init.sql \ - -d postgres:13 + -d postgres:15 diff --git a/hack/test.env b/hack/test.env index fb1c56c3a..9ed4c3d58 100644 --- a/hack/test.env +++ b/hack/test.env @@ -9,9 +9,11 @@ GOTRUE_DB_AUTOMIGRATE=true DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/postgres" GOTRUE_API_HOST=localhost PORT=9999 -GOTRUE_LOG_LEVEL=debug +API_EXTERNAL_URL="http://localhost:9999" +GOTRUE_LOG_SQL=none +GOTRUE_LOG_LEVEL=warn GOTRUE_SITE_URL=https://example.netlify.com -GOTRUE_URI_ALLOW_LIST="http://localhost:3000" +GOTRUE_URI_ALLOW_LIST="http://localhost:3000,https://supabase.com/" GOTRUE_OPERATOR_TOKEN=foobar GOTRUE_EXTERNAL_APPLE_ENABLED=true GOTRUE_EXTERNAL_APPLE_CLIENT_ID=testclientid @@ -33,14 +35,35 @@ GOTRUE_EXTERNAL_FACEBOOK_ENABLED=true GOTRUE_EXTERNAL_FACEBOOK_CLIENT_ID=testclientid GOTRUE_EXTERNAL_FACEBOOK_SECRET=testsecret GOTRUE_EXTERNAL_FACEBOOK_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_FLY_ENABLED=true +GOTRUE_EXTERNAL_FLY_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_FLY_SECRET=testsecret +GOTRUE_EXTERNAL_FLY_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_FIGMA_ENABLED=true +GOTRUE_EXTERNAL_FIGMA_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_FIGMA_SECRET=testsecret +GOTRUE_EXTERNAL_FIGMA_REDIRECT_URI=https://identity.services.netlify.com/callback GOTRUE_EXTERNAL_GITHUB_ENABLED=true GOTRUE_EXTERNAL_GITHUB_CLIENT_ID=testclientid GOTRUE_EXTERNAL_GITHUB_SECRET=testsecret GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_KAKAO_ENABLED=true +GOTRUE_EXTERNAL_KAKAO_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_KAKAO_SECRET=testsecret +GOTRUE_EXTERNAL_KAKAO_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_KEYCLOAK_ENABLED=true +GOTRUE_EXTERNAL_KEYCLOAK_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_KEYCLOAK_SECRET=testsecret +GOTRUE_EXTERNAL_KEYCLOAK_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_KEYCLOAK_URL=https://keycloak.example.com/auth/realms/myrealm GOTRUE_EXTERNAL_LINKEDIN_ENABLED=true GOTRUE_EXTERNAL_LINKEDIN_CLIENT_ID=testclientid GOTRUE_EXTERNAL_LINKEDIN_SECRET=testsecret GOTRUE_EXTERNAL_LINKEDIN_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_LINKEDIN_OIDC_ENABLED=true +GOTRUE_EXTERNAL_LINKEDIN_OIDC_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_LINKEDIN_OIDC_SECRET=testsecret +GOTRUE_EXTERNAL_LINKEDIN_OIDC_REDIRECT_URI=https://identity.services.netlify.com/callback GOTRUE_EXTERNAL_GITLAB_ENABLED=true GOTRUE_EXTERNAL_GITLAB_CLIENT_ID=testclientid GOTRUE_EXTERNAL_GITLAB_SECRET=testsecret @@ -61,6 +84,14 @@ GOTRUE_EXTERNAL_SLACK_ENABLED=true GOTRUE_EXTERNAL_SLACK_CLIENT_ID=testclientid GOTRUE_EXTERNAL_SLACK_SECRET=testsecret GOTRUE_EXTERNAL_SLACK_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_SLACK_OIDC_ENABLED=true +GOTRUE_EXTERNAL_SLACK_OIDC_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_SLACK_OIDC_SECRET=testsecret +GOTRUE_EXTERNAL_SLACK_OIDC_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_WORKOS_ENABLED=true +GOTRUE_EXTERNAL_WORKOS_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_WORKOS_SECRET=testsecret +GOTRUE_EXTERNAL_WORKOS_REDIRECT_URI=https://identity.services.netlify.com/callback GOTRUE_EXTERNAL_TWITCH_ENABLED=true GOTRUE_EXTERNAL_TWITCH_CLIENT_ID=testclientid GOTRUE_EXTERNAL_TWITCH_SECRET=testsecret @@ -69,14 +100,30 @@ GOTRUE_EXTERNAL_TWITTER_ENABLED=true GOTRUE_EXTERNAL_TWITTER_CLIENT_ID=testclientid GOTRUE_EXTERNAL_TWITTER_SECRET=testsecret GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI=https://identity.services.netlify.com/callback -GOTRUE_EXTERNAL_SAML_ENABLED=true -GOTRUE_EXTERNAL_SAML_METADATA_URL= -GOTRUE_EXTERNAL_SAML_API_BASE=http://localhost -GOTRUE_EXTERNAL_SAML_NAME=TestSamlName -GOTRUE_TRACING_ENABLED=false +GOTRUE_EXTERNAL_ZOOM_ENABLED=true +GOTRUE_EXTERNAL_ZOOM_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_ZOOM_SECRET=testsecret +GOTRUE_EXTERNAL_ZOOM_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_FLOW_STATE_EXPIRY_DURATION="300s" +GOTRUE_EXTERNAL_WEB3_SOLANA_ENABLED="true" +GOTRUE_RATE_LIMIT_VERIFY="100000" +GOTRUE_RATE_LIMIT_TOKEN_REFRESH="30" +GOTRUE_RATE_LIMIT_ANONYMOUS_USERS="5" +GOTRUE_RATE_LIMIT_HEADER="My-Custom-Header" +GOTRUE_TRACING_ENABLED=true +GOTRUE_TRACING_EXPORTER=default GOTRUE_TRACING_HOST=127.0.0.1 GOTRUE_TRACING_PORT=8126 GOTRUE_TRACING_TAGS="env:test" GOTRUE_SECURITY_CAPTCHA_ENABLED="false" GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha" GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000" +GOTRUE_SECURITY_CAPTCHA_TIMEOUT="10s" +GOTRUE_SAML_ENABLED="true" +GOTRUE_SAML_PRIVATE_KEY="MIIEowIBAAKCAQEAszrVveMQcSsa0Y+zN1ZFb19cRS0jn4UgIHTprW2tVBmO2PABzjY3XFCfx6vPirMAPWBYpsKmXrvm1tr0A6DZYmA8YmJd937VUQ67fa6DMyppBYTjNgGEkEhmKuszvF3MARsIKCGtZqUrmS7UG4404wYxVppnr2EYm3RGtHlkYsXu20MBqSDXP47bQP+PkJqC3BuNGk3xt5UHl2FSFpTHelkI6lBynw16B+lUT1F96SERNDaMqi/TRsZdGe5mB/29ngC/QBMpEbRBLNRir5iUevKS7Pn4aph9Qjaxx/97siktK210FJT23KjHpgcUfjoQ6BgPBTLtEeQdRyDuc/CgfwIDAQABAoIBAGYDWOEpupQPSsZ4mjMnAYJwrp4ZISuMpEqVAORbhspVeb70bLKonT4IDcmiexCg7cQBcLQKGpPVM4CbQ0RFazXZPMVq470ZDeWDEyhoCfk3bGtdxc1Zc9CDxNMs6FeQs6r1beEZug6weG5J/yRn/qYxQife3qEuDMl+lzfl2EN3HYVOSnBmdt50dxRuX26iW3nqqbMRqYn9OHuJ1LvRRfYeyVKqgC5vgt/6Tf7DAJwGe0dD7q08byHV8DBZ0pnMVU0bYpf1GTgMibgjnLjK//EVWafFHtN+RXcjzGmyJrk3+7ZyPUpzpDjO21kpzUQLrpEkkBRnmg6bwHnSrBr8avECgYEA3pq1PTCAOuLQoIm1CWR9/dhkbJQiKTJevlWV8slXQLR50P0WvI2RdFuSxlWmA4xZej8s4e7iD3MYye6SBsQHygOVGc4efvvEZV8/XTlDdyj7iLVGhnEmu2r7AFKzy8cOvXx0QcLg+zNd7vxZv/8D3Qj9Jje2LjLHKM5n/dZ3RzUCgYEAzh5Lo2anc4WN8faLGt7rPkGQF+7/18ImQE11joHWa3LzAEy7FbeOGpE/vhOv5umq5M/KlWFIRahMEQv4RusieHWI19ZLIP+JwQFxWxS+cPp3xOiGcquSAZnlyVSxZ//dlVgaZq2o2MfrxECcovRlaknl2csyf+HjFFwKlNxHm2MCgYAr//R3BdEy0oZeVRndo2lr9YvUEmu2LOihQpWDCd0fQw0ZDA2kc28eysL2RROte95r1XTvq6IvX5a0w11FzRWlDpQ4J4/LlcQ6LVt+98SoFwew+/PWuyLmxLycUbyMOOpm9eSc4wJJZNvaUzMCSkvfMtmm5jgyZYMMQ9A2Ul/9SQKBgB9mfh9mhBwVPIqgBJETZMMXOdxrjI5SBYHGSyJqpT+5Q0vIZLfqPrvNZOiQFzwWXPJ+tV4Mc/YorW3rZOdo6tdvEGnRO6DLTTEaByrY/io3/gcBZXoSqSuVRmxleqFdWWRnB56c1hwwWLqNHU+1671FhL6pNghFYVK4suP6qu4BAoGBAMk+VipXcIlD67mfGrET/xDqiWWBZtgTzTMjTpODhDY1GZck1eb4CQMP5j5V3gFJ4cSgWDJvnWg8rcz0unz/q4aeMGl1rah5WNDWj1QKWMS6vJhMHM/rqN1WHWR0ZnV83svYgtg0zDnQKlLujqW4JmGXLMU7ur6a+e6lpa1fvLsP" +GOTRUE_MAX_VERIFIED_FACTORS=10 +GOTRUE_SMS_TEST_OTP_VALID_UNTIL="" +GOTRUE_SECURITY_DB_ENCRYPTION_ENCRYPT=true +GOTRUE_SECURITY_DB_ENCRYPTION_ENCRYPTION_KEY_ID=abc +GOTRUE_SECURITY_DB_ENCRYPTION_ENCRYPTION_KEY=pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4 +GOTRUE_SECURITY_DB_ENCRYPTION_DECRYPTION_KEYS=abc:pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4 diff --git a/init_postgres.sh b/init_postgres.sh new file mode 100755 index 000000000..134e17901 --- /dev/null +++ b/init_postgres.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL + CREATE USER supabase_admin LOGIN CREATEROLE CREATEDB REPLICATION BYPASSRLS; + + -- Supabase super admin + CREATE USER supabase_auth_admin NOINHERIT CREATEROLE LOGIN NOREPLICATION PASSWORD 'root'; + CREATE SCHEMA IF NOT EXISTS $DB_NAMESPACE AUTHORIZATION supabase_auth_admin; + GRANT CREATE ON DATABASE postgres TO supabase_auth_admin; + ALTER USER supabase_auth_admin SET search_path = '$DB_NAMESPACE'; +EOSQL diff --git a/internal/api/admin.go b/internal/api/admin.go new file mode 100644 index 000000000..9d3be8c77 --- /dev/null +++ b/internal/api/admin.go @@ -0,0 +1,643 @@ +package api + +import ( + "context" + "net/http" + "time" + + "github.com/fatih/structs" + "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/sethvargo/go-password/password" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/crypto/bcrypt" +) + +type AdminUserParams struct { + Id string `json:"id"` + Aud string `json:"aud"` + Role string `json:"role"` + Email string `json:"email"` + Phone string `json:"phone"` + Password *string `json:"password"` + PasswordHash string `json:"password_hash"` + EmailConfirm bool `json:"email_confirm"` + PhoneConfirm bool `json:"phone_confirm"` + UserMetaData map[string]interface{} `json:"user_metadata"` + AppMetaData map[string]interface{} `json:"app_metadata"` + BanDuration string `json:"ban_duration"` +} + +type adminUserDeleteParams struct { + ShouldSoftDelete bool `json:"should_soft_delete"` +} + +type adminUserUpdateFactorParams struct { + FriendlyName string `json:"friendly_name"` + Phone string `json:"phone"` +} + +type AdminListUsersResponse struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` +} + +func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + + userID, err := uuid.FromString(chi.URLParam(r, "user_id")) + if err != nil { + return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID") + } + + observability.LogEntrySetField(r, "user_id", userID) + + u, err := models.FindUserByID(db, userID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(apierrors.ErrorCodeUserNotFound, "User not found") + } + return nil, internalServerError("Database error loading user").WithInternalError(err) + } + + return withUser(ctx, u), nil +} + +// Use only after requireAuthentication, so that there is a valid user +func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + user := getUser(ctx) + factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) + if err != nil { + return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID") + } + + observability.LogEntrySetField(r, "factor_id", factorID) + + factor, err := user.FindOwnedFactorByID(db, factorID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found") + } + return nil, internalServerError("Database error loading factor").WithInternalError(err) + } + return withFactor(ctx, factor), nil +} + +func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { + params := &AdminUserParams{} + if err := retrieveRequestParams(r, params); err != nil { + return nil, err + } + + return params, nil +} + +// adminUsers responds with a list of all users in a given audience +func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + aud := a.requestAud(ctx, r) + + pageParams, err := paginate(r) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) + } + + sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err) + } + + filter := r.URL.Query().Get("filter") + + users, err := models.FindUsersInAudience(db, aud, pageParams, sortParams, filter) + if err != nil { + return internalServerError("Database error finding users").WithInternalError(err) + } + addPaginationHeaders(w, r, pageParams) + + return sendJSON(w, http.StatusOK, AdminListUsersResponse{ + Users: users, + Aud: aud, + }) +} + +// adminUserGet returns information about a single user +func (a *API) adminUserGet(w http.ResponseWriter, r *http.Request) error { + user := getUser(r.Context()) + + return sendJSON(w, http.StatusOK, user) +} + +// adminUserUpdate updates a single user object +func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + user := getUser(ctx) + adminUser := getAdminUser(ctx) + params, err := a.getAdminParams(r) + if err != nil { + return err + } + + if params.Email != "" { + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + } + + if params.Phone != "" { + params.Phone, err = validatePhone(params.Phone) + if err != nil { + return err + } + } + + var banDuration *time.Duration + if params.BanDuration != "" { + duration := time.Duration(0) + if params.BanDuration != "none" { + duration, err = time.ParseDuration(params.BanDuration) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) + } + } + banDuration = &duration + } + + if params.Password != nil { + password := *params.Password + + if err := a.checkPasswordStrength(ctx, password); err != nil { + return err + } + + if err := user.SetPassword(ctx, password, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + if params.Role != "" { + if terr := user.SetRole(tx, params.Role); terr != nil { + return terr + } + } + + if params.EmailConfirm { + if terr := user.Confirm(tx); terr != nil { + return terr + } + } + + if params.PhoneConfirm { + if terr := user.ConfirmPhone(tx); terr != nil { + return terr + } + } + + if params.Password != nil { + if terr := user.UpdatePassword(tx, nil); terr != nil { + return terr + } + } + + var identities []models.Identity + if params.Email != "" { + if identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), "email"); terr != nil && !models.IsNotFoundError(terr) { + return terr + } else if identity == nil { + // if the user doesn't have an existing email + // then updating the user's email should create a new email identity + i, terr := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: params.Email, + EmailVerified: params.EmailConfirm, + })) + if terr != nil { + return terr + } + identities = append(identities, *i) + } else { + // update the existing email identity + if terr := identity.UpdateIdentityData(tx, map[string]interface{}{ + "email": params.Email, + "email_verified": params.EmailConfirm, + }); terr != nil { + return terr + } + } + if user.IsAnonymous && params.EmailConfirm { + user.IsAnonymous = false + if terr := tx.UpdateOnly(user, "is_anonymous"); terr != nil { + return terr + } + } + + if terr := user.SetEmail(tx, params.Email); terr != nil { + return terr + } + } + + if params.Phone != "" { + if identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), "phone"); terr != nil && !models.IsNotFoundError(terr) { + return terr + } else if identity == nil { + // if the user doesn't have an existing phone + // then updating the user's phone should create a new phone identity + identity, terr := a.createNewIdentity(tx, user, "phone", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Phone: params.Phone, + PhoneVerified: params.PhoneConfirm, + })) + if terr != nil { + return terr + } + identities = append(identities, *identity) + } else { + // update the existing phone identity + if terr := identity.UpdateIdentityData(tx, map[string]interface{}{ + "phone": params.Phone, + "phone_verified": params.PhoneConfirm, + }); terr != nil { + return terr + } + } + if user.IsAnonymous && params.PhoneConfirm { + user.IsAnonymous = false + if terr := tx.UpdateOnly(user, "is_anonymous"); terr != nil { + return terr + } + } + if terr := user.SetPhone(tx, params.Phone); terr != nil { + return terr + } + } + user.Identities = append(user.Identities, identities...) + + if params.AppMetaData != nil { + if terr := user.UpdateAppMetaData(tx, params.AppMetaData); terr != nil { + return terr + } + } + + if params.UserMetaData != nil { + if terr := user.UpdateUserMetaData(tx, params.UserMetaData); terr != nil { + return terr + } + } + + if banDuration != nil { + if terr := user.Ban(tx, *banDuration); terr != nil { + return terr + } + } + + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserModifiedAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + "user_phone": user.Phone, + }); terr != nil { + return terr + } + return nil + }) + + if err != nil { + return internalServerError("Error updating user").WithInternalError(err) + } + + return sendJSON(w, http.StatusOK, user) +} + +// adminUserCreate creates a new user based on the provided data +func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + adminUser := getAdminUser(ctx) + params, err := a.getAdminParams(r) + if err != nil { + return err + } + + aud := a.requestAud(ctx, r) + if params.Aud != "" { + aud = params.Aud + } + + if params.Email == "" && params.Phone == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone") + } + + var providers []string + if params.Email != "" { + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil { + return internalServerError("Database error checking email").WithInternalError(err) + } else if user != nil { + return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + } + providers = append(providers, "email") + } + + if params.Phone != "" { + params.Phone, err = validatePhone(params.Phone) + if err != nil { + return err + } + if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { + return internalServerError("Database error checking phone").WithInternalError(err) + } else if exists { + return unprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user") + } + providers = append(providers, "phone") + } + + if params.Password != nil && params.PasswordHash != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided") + } + + if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" { + password, err := password.Generate(64, 10, 0, false, true) + if err != nil { + return internalServerError("Error generating password").WithInternalError(err) + } + params.Password = &password + } + + var user *models.User + if params.PasswordHash != "" { + user, err = models.NewUserWithPasswordHash(params.Phone, params.Email, params.PasswordHash, aud, params.UserMetaData) + } else { + user, err = models.NewUser(params.Phone, params.Email, *params.Password, aud, params.UserMetaData) + } + + if err != nil { + if errors.Is(err, bcrypt.ErrPasswordTooLong) { + return badRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) + } + return internalServerError("Error creating user").WithInternalError(err) + } + + if params.Id != "" { + customId, err := uuid.FromString(params.Id) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format") + } + if customId == uuid.Nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid") + } + user.ID = customId + } + + user.AppMetaData = map[string]interface{}{ + // TODO: Deprecate "provider" field + // default to the first provider in the providers slice + "provider": providers[0], + "providers": providers, + } + + var banDuration *time.Duration + if params.BanDuration != "" { + duration := time.Duration(0) + if params.BanDuration != "none" { + duration, err = time.ParseDuration(params.BanDuration) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) + } + } + banDuration = &duration + } + + err = db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(user); terr != nil { + return terr + } + + var identities []models.Identity + if user.GetEmail() != "" { + identity, terr := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + + if terr != nil { + return terr + } + identities = append(identities, *identity) + } + + if user.GetPhone() != "" { + identity, terr := a.createNewIdentity(tx, user, "phone", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Phone: user.GetPhone(), + })) + + if terr != nil { + return terr + } + identities = append(identities, *identity) + } + + user.Identities = identities + + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserSignedUpAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + "user_phone": user.Phone, + }); terr != nil { + return terr + } + + role := config.JWT.DefaultGroupName + if params.Role != "" { + role = params.Role + } + if terr := user.SetRole(tx, role); terr != nil { + return terr + } + + if params.AppMetaData != nil { + if terr := user.UpdateAppMetaData(tx, params.AppMetaData); terr != nil { + return terr + } + } + + if params.EmailConfirm { + if terr := user.Confirm(tx); terr != nil { + return terr + } + } + + if params.PhoneConfirm { + if terr := user.ConfirmPhone(tx); terr != nil { + return terr + } + } + + if banDuration != nil { + if terr := user.Ban(tx, *banDuration); terr != nil { + return terr + } + } + + return nil + }) + + if err != nil { + return internalServerError("Database error creating new user").WithInternalError(err) + } + + return sendJSON(w, http.StatusOK, user) +} + +// adminUserDelete deletes a user +func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + adminUser := getAdminUser(ctx) + + // ShouldSoftDelete defaults to false + params := &adminUserDeleteParams{} + if body, _ := utilities.GetBodyBytes(r); len(body) != 0 { + // we only want to parse the body if it's not empty + // retrieveRequestParams will handle any errors with stream + if err := retrieveRequestParams(r, params); err != nil { + return err + } + } + + err := a.db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserDeletedAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + "user_phone": user.Phone, + }); terr != nil { + return internalServerError("Error recording audit log entry").WithInternalError(terr) + } + + if params.ShouldSoftDelete { + if user.DeletedAt != nil { + // user has been soft deleted already + return nil + } + if terr := user.SoftDeleteUser(tx); terr != nil { + return internalServerError("Error soft deleting user").WithInternalError(terr) + } + + if terr := user.SoftDeleteUserIdentities(tx); terr != nil { + return internalServerError("Error soft deleting user identities").WithInternalError(terr) + } + + // hard delete all associated factors + if terr := models.DeleteFactorsByUserId(tx, user.ID); terr != nil { + return internalServerError("Error deleting user's factors").WithInternalError(terr) + } + // hard delete all associated sessions + if terr := models.Logout(tx, user.ID); terr != nil { + return internalServerError("Error deleting user's sessions").WithInternalError(terr) + } + } else { + if terr := tx.Destroy(user); terr != nil { + return internalServerError("Database error deleting user").WithInternalError(terr) + } + } + + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{}) +} + +func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + factor := getFactor(ctx) + + err := a.db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.DeleteFactorAction, r.RemoteAddr, map[string]interface{}{ + "user_id": user.ID, + "factor_id": factor.ID, + }); terr != nil { + return terr + } + if terr := tx.Destroy(factor); terr != nil { + return internalServerError("Database error deleting factor").WithInternalError(terr) + } + return nil + }) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, factor) +} + +func (a *API) adminUserGetFactors(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + return sendJSON(w, http.StatusOK, user.Factors) +} + +// adminUserUpdate updates a single factor object +func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + factor := getFactor(ctx) + user := getUser(ctx) + adminUser := getAdminUser(ctx) + params := &adminUserUpdateFactorParams{} + + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + err := a.db.Transaction(func(tx *storage.Connection) error { + if params.FriendlyName != "" { + if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil { + return terr + } + } + + if params.Phone != "" && factor.IsPhoneFactor() { + phone, err := validatePhone(params.Phone) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + } + if terr := factor.UpdatePhone(tx, phone); terr != nil { + return terr + } + } + + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UpdateFactorAction, "", map[string]interface{}{ + "user_id": user.ID, + "factor_id": factor.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, factor) +} diff --git a/internal/api/admin_test.go b/internal/api/admin_test.go new file mode 100644 index 000000000..c72694f78 --- /dev/null +++ b/internal/api/admin_test.go @@ -0,0 +1,916 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type AdminTestSuite struct { + suite.Suite + User *models.User + API *API + Config *conf.GlobalConfiguration + + token string +} + +func TestAdmin(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &AdminTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *AdminTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + ts.Config.External.Email.Enabled = true + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err, "Error generating admin jwt") + ts.token = token +} + +// TestAdminUsersUnauthorized tests API /admin/users route without authentication +func (ts *AdminTestSuite) TestAdminUsersUnauthorized() { + req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers() { + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + assert.Equal(ts.T(), "; rel=\"last\"", w.Header().Get("Link")) + assert.Equal(ts.T(), "0", w.Header().Get("X-Total-Count")) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_Pagination() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + u, err = models.NewUser("987654321", "test2@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users?per_page=1", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + assert.Equal(ts.T(), "; rel=\"next\", ; rel=\"last\"", w.Header().Get("Link")) + assert.Equal(ts.T(), "2", w.Header().Get("X-Total-Count")) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + for _, user := range data["users"].([]interface{}) { + assert.NotEmpty(ts.T(), user) + } +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_SortAsc() { + u, err := models.NewUser("", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + u.CreatedAt = time.Now().Add(-time.Minute) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + u, err = models.NewUser("", "test2@example.com", "test", ts.Config.JWT.Aud, nil) + u.CreatedAt = time.Now() + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) + qv := req.URL.Query() + qv.Set("sort", "created_at asc") + req.URL.RawQuery = qv.Encode() + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` + }{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Len(ts.T(), data.Users, 2) + assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) + assert.Equal(ts.T(), "test2@example.com", data.Users[1].GetEmail()) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_SortDesc() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + u.CreatedAt = time.Now().Add(-time.Minute) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + u, err = models.NewUser("987654321", "test2@example.com", "test", ts.Config.JWT.Aud, nil) + u.CreatedAt = time.Now() + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` + }{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Len(ts.T(), data.Users, 2) + assert.Equal(ts.T(), "test2@example.com", data.Users[0].GetEmail()) + assert.Equal(ts.T(), "test1@example.com", data.Users[1].GetEmail()) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_FilterEmail() { + u, err := models.NewUser("", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users?filter=test1", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` + }{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Len(ts.T(), data.Users, 1) + assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_FilterName() { + u, err := models.NewUser("", "test1@example.com", "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + u, err = models.NewUser("", "test2@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users?filter=User", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` + }{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Len(ts.T(), data.Users, 1) + assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) +} + +// TestAdminUserCreate tests API /admin/user route (POST) +func (ts *AdminTestSuite) TestAdminUserCreate() { + cases := []struct { + desc string + params map[string]interface{} + expected map[string]interface{} + }{ + { + desc: "Only phone", + params: map[string]interface{}{ + "phone": "123456789", + "password": "test1", + }, + expected: map[string]interface{}{ + "email": "", + "phone": "123456789", + "isAuthenticated": true, + "provider": "phone", + "providers": []string{"phone"}, + "password": "test1", + }, + }, + { + desc: "With password", + params: map[string]interface{}{ + "email": "test1@example.com", + "phone": "123456789", + "password": "test1", + }, + expected: map[string]interface{}{ + "email": "test1@example.com", + "phone": "123456789", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email", "phone"}, + "password": "test1", + }, + }, + { + desc: "Without password", + params: map[string]interface{}{ + "email": "test2@example.com", + "phone": "", + }, + expected: map[string]interface{}{ + "email": "test2@example.com", + "phone": "", + "isAuthenticated": false, + "provider": "email", + "providers": []string{"email"}, + }, + }, + { + desc: "With empty string password", + params: map[string]interface{}{ + "email": "test3@example.com", + "phone": "", + "password": "", + }, + expected: map[string]interface{}{ + "email": "test3@example.com", + "phone": "", + "isAuthenticated": false, + "provider": "email", + "providers": []string{"email"}, + "password": "", + }, + }, + { + desc: "Ban created user", + params: map[string]interface{}{ + "email": "test4@example.com", + "phone": "", + "password": "test1", + "ban_duration": "24h", + }, + expected: map[string]interface{}{ + "email": "test4@example.com", + "phone": "", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email"}, + "password": "test1", + }, + }, + { + desc: "With password hash", + params: map[string]interface{}{ + "email": "test5@example.com", + "password_hash": "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq", + }, + expected: map[string]interface{}{ + "email": "test5@example.com", + "phone": "", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email"}, + "password": "test", + }, + }, + { + desc: "With custom id", + params: map[string]interface{}{ + "id": "fc56ab41-2010-4870-a9b9-767c1dc573fb", + "email": "test6@example.com", + "password": "test", + }, + expected: map[string]interface{}{ + "id": "fc56ab41-2010-4870-a9b9-767c1dc573fb", + "email": "test6@example.com", + "phone": "", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email"}, + "password": "test", + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + ts.Config.External.Phone.Enabled = true + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + assert.Equal(ts.T(), c.expected["email"], data.GetEmail()) + assert.Equal(ts.T(), c.expected["phone"], data.GetPhone()) + assert.Equal(ts.T(), c.expected["provider"], data.AppMetaData["provider"]) + assert.ElementsMatch(ts.T(), c.expected["providers"], data.AppMetaData["providers"]) + + u, err := models.FindUserByID(ts.API.db, data.ID) + require.NoError(ts.T(), err) + + // verify that the corresponding identities were created + require.NotEmpty(ts.T(), u.Identities) + for _, identity := range u.Identities { + require.Equal(ts.T(), u.ID, identity.UserID) + if identity.Provider == "email" { + require.Equal(ts.T(), c.expected["email"], identity.IdentityData["email"]) + } + if identity.Provider == "phone" { + require.Equal(ts.T(), c.expected["phone"], identity.IdentityData["phone"]) + } + } + + if _, ok := c.expected["password"]; ok { + expectedPassword := fmt.Sprintf("%v", c.expected["password"]) + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, expectedPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expected["isAuthenticated"], isAuthenticated) + } + + if id, ok := c.expected["id"]; ok { + uid, err := uuid.FromString(id.(string)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), uid, data.ID) + } + + // remove created user after each case + require.NoError(ts.T(), ts.API.db.Destroy(u)) + }) + } +} + +// TestAdminUserGet tests API /admin/user route (GET) +func (ts *AdminTestSuite) TestAdminUserGet() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test Get User"}) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/users/%s", u.ID), nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + assert.Equal(ts.T(), data["email"], "test1@example.com") + assert.NotNil(ts.T(), data["app_metadata"]) + assert.NotNil(ts.T(), data["user_metadata"]) + md := data["user_metadata"].(map[string]interface{}) + assert.Len(ts.T(), md, 1) + assert.Equal(ts.T(), "Test Get User", md["full_name"]) +} + +// TestAdminUserUpdate tests API /admin/user route (UPDATE) +func (ts *AdminTestSuite) TestAdminUserUpdate() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + var buffer bytes.Buffer + newEmail := "test2@example.com" + newPhone := "234567890" + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "role": "testing", + "app_metadata": map[string]interface{}{ + "roles": []string{"writer", "editor"}, + }, + "user_metadata": map[string]interface{}{ + "name": "David", + }, + "ban_duration": "24h", + "email": newEmail, + "phone": newPhone, + })) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/users/%s", u.ID), &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + assert.Equal(ts.T(), "testing", data.Role) + assert.NotNil(ts.T(), data.UserMetaData) + assert.Equal(ts.T(), "David", data.UserMetaData["name"]) + assert.Equal(ts.T(), newEmail, data.GetEmail()) + assert.Equal(ts.T(), newPhone, data.GetPhone()) + + assert.NotNil(ts.T(), data.AppMetaData) + assert.Len(ts.T(), data.AppMetaData["roles"], 2) + assert.Contains(ts.T(), data.AppMetaData["roles"], "writer") + assert.Contains(ts.T(), data.AppMetaData["roles"], "editor") + assert.NotNil(ts.T(), data.BannedUntil) + + u, err = models.FindUserByID(ts.API.db, data.ID) + require.NoError(ts.T(), err) + + // check if the corresponding identities were successfully created + require.NotEmpty(ts.T(), u.Identities) + + for _, identity := range u.Identities { + // for email & phone identities, the providerId is the same as the userId + require.Equal(ts.T(), u.ID.String(), identity.ProviderID) + require.Equal(ts.T(), u.ID, identity.UserID) + if identity.Provider == "email" { + require.Equal(ts.T(), newEmail, identity.IdentityData["email"]) + } + if identity.Provider == "phone" { + require.Equal(ts.T(), newPhone, identity.IdentityData["phone"]) + + } + } +} + +func (ts *AdminTestSuite) TestAdminUserUpdatePasswordFailed() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + var updateEndpoint = fmt.Sprintf("/admin/users/%s", u.ID) + ts.Config.Password.MinLength = 6 + ts.Run("Password doesn't meet minimum length", func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": "", + })) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, updateEndpoint, &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusUnprocessableEntity, w.Code) + }) +} + +func (ts *AdminTestSuite) TestAdminUserUpdateBannedUntilFailed() { + u, err := models.NewUser("", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + var updateEndpoint = fmt.Sprintf("/admin/users/%s", u.ID) + ts.Config.Password.MinLength = 6 + ts.Run("Incorrect format for ban_duration", func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "ban_duration": "24", + })) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, updateEndpoint, &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusBadRequest, w.Code) + }) +} + +// TestAdminUserDelete tests API /admin/users route (DELETE) +func (ts *AdminTestSuite) TestAdminUserDelete() { + type expected struct { + code int + err error + } + signupParams := &SignupParams{ + Email: "test-delete@example.com", + Password: "test", + Data: map[string]interface{}{"name": "test"}, + Provider: "email", + Aud: ts.Config.JWT.Aud, + } + cases := []struct { + desc string + body map[string]interface{} + isSoftDelete string + isSSOUser bool + expected expected + }{ + { + desc: "Test admin delete user (default)", + isSoftDelete: "", + isSSOUser: false, + expected: expected{code: http.StatusOK, err: models.UserNotFoundError{}}, + body: nil, + }, + { + desc: "Test admin delete user (hard deletion)", + isSoftDelete: "?is_soft_delete=false", + isSSOUser: false, + expected: expected{code: http.StatusOK, err: models.UserNotFoundError{}}, + body: map[string]interface{}{ + "should_soft_delete": false, + }, + }, + { + desc: "Test admin delete user (soft deletion)", + isSoftDelete: "?is_soft_delete=true", + isSSOUser: false, + expected: expected{code: http.StatusOK, err: models.UserNotFoundError{}}, + body: map[string]interface{}{ + "should_soft_delete": true, + }, + }, + { + desc: "Test admin delete user (soft deletion & sso user)", + isSoftDelete: "?is_soft_delete=true", + isSSOUser: true, + expected: expected{code: http.StatusOK, err: nil}, + body: map[string]interface{}{ + "should_soft_delete": true, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + u, err := signupParams.ToUserModel(false /* <- isSSOUser */) + require.NoError(ts.T(), err) + u, err = ts.API.signupNewUser(ts.API.db, u) + require.NoError(ts.T(), err) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/users/%s", u.ID), &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected.code, w.Code) + + if c.isSSOUser { + u, err = models.FindUserByID(ts.API.db, u.ID) + require.NotNil(ts.T(), u) + } else { + _, err = models.FindUserByEmailAndAudience(ts.API.db, signupParams.Email, ts.Config.JWT.Aud) + } + require.Equal(ts.T(), c.expected.err, err) + }) + } +} + +func (ts *AdminTestSuite) TestAdminUserSoftDeletion() { + // create user + u, err := models.NewUser("123456789", "test@example.com", "secret", ts.Config.JWT.Aud, map[string]interface{}{"name": "test"}) + require.NoError(ts.T(), err) + u.ConfirmationToken = "some_token" + u.RecoveryToken = "some_token" + u.EmailChangeTokenCurrent = "some_token" + u.EmailChangeTokenNew = "some_token" + u.PhoneChangeToken = "some_token" + u.AppMetaData = map[string]interface{}{ + "provider": "email", + } + require.NoError(ts.T(), ts.API.db.Create(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetPhone(), u.PhoneChangeToken, models.PhoneChangeToken)) + + // create user identities + _, err = ts.API.createNewIdentity(ts.API.db, u, "email", map[string]interface{}{ + "sub": "123456", + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + _, err = ts.API.createNewIdentity(ts.API.db, u, "github", map[string]interface{}{ + "sub": "234567", + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "should_soft_delete": true, + })) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/users/%s", u.ID), &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // get soft-deleted user from db + deletedUser, err := models.FindUserByID(ts.API.db, u.ID) + require.NoError(ts.T(), err) + + require.Empty(ts.T(), deletedUser.ConfirmationToken) + require.Empty(ts.T(), deletedUser.RecoveryToken) + require.Empty(ts.T(), deletedUser.EmailChangeTokenCurrent) + require.Empty(ts.T(), deletedUser.EmailChangeTokenNew) + require.Empty(ts.T(), deletedUser.EncryptedPassword) + require.Empty(ts.T(), deletedUser.PhoneChangeToken) + require.Empty(ts.T(), deletedUser.UserMetaData) + require.Empty(ts.T(), deletedUser.AppMetaData) + require.NotEmpty(ts.T(), deletedUser.DeletedAt) + require.NotEmpty(ts.T(), deletedUser.GetEmail()) + + // get soft-deleted user's identity from db + deletedIdentities, err := models.FindIdentitiesByUserID(ts.API.db, deletedUser.ID) + require.NoError(ts.T(), err) + + for _, identity := range deletedIdentities { + require.Empty(ts.T(), identity.IdentityData) + } +} + +func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() { + var cases = []struct { + desc string + customConfig *conf.GlobalConfiguration + userData map[string]interface{} + expected int + }{ + { + desc: "Email Signups Disabled", + customConfig: &conf.GlobalConfiguration{ + JWT: ts.Config.JWT, + External: conf.ProviderConfiguration{ + Email: conf.EmailProviderConfiguration{ + Enabled: false, + }, + }, + }, + userData: map[string]interface{}{ + "email": "test1@example.com", + "password": "test1", + }, + expected: http.StatusOK, + }, + { + desc: "Phone Signups Disabled", + customConfig: &conf.GlobalConfiguration{ + JWT: ts.Config.JWT, + External: conf.ProviderConfiguration{ + Phone: conf.PhoneProviderConfiguration{ + Enabled: false, + }, + }, + }, + userData: map[string]interface{}{ + "phone": "123456789", + "password": "test1", + }, + expected: http.StatusOK, + }, + { + desc: "All Signups Disabled", + customConfig: &conf.GlobalConfiguration{ + JWT: ts.Config.JWT, + DisableSignup: true, + }, + userData: map[string]interface{}{ + "email": "test2@example.com", + "password": "test2", + }, + expected: http.StatusOK, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Initialize user data + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.userData)) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.Config.JWT = c.customConfig.JWT + ts.Config.External = c.customConfig.External + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected, w.Code) + }) + } +} + +// TestAdminUserDeleteFactor tests API /admin/users//factors// +func (ts *AdminTestSuite) TestAdminUserDeleteFactor() { + u, err := models.NewUser("123456789", "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + f := models.NewTOTPFactor(u, "testSimpleName") + require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified)) + require.NoError(ts.T(), f.SetSecret("secretkey", ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey)) + require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/users/%s/factors/%s/", u.ID, f.ID), nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + _, err = models.FindFactorByFactorID(ts.API.db, f.ID) + require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error()) + +} + +// TestAdminUserGetFactor tests API /admin/user//factors/ +func (ts *AdminTestSuite) TestAdminUserGetFactors() { + u, err := models.NewUser("123456789", "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + f := models.NewTOTPFactor(u, "testSimpleName") + require.NoError(ts.T(), f.SetSecret("secretkey", ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey)) + require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/users/%s/factors/", u.ID), nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + getFactorsResp := []*models.Factor{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getFactorsResp)) + require.Equal(ts.T(), getFactorsResp[0].Secret, "") +} + +func (ts *AdminTestSuite) TestAdminUserUpdateFactor() { + u, err := models.NewUser("123456789", "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + f := models.NewPhoneFactor(u, "123456789", "testSimpleName") + require.NoError(ts.T(), f.SetSecret("secretkey", ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey)) + require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor") + + var cases = []struct { + Desc string + FactorData map[string]interface{} + ExpectedCode int + }{ + { + Desc: "Update Factor friendly name", + FactorData: map[string]interface{}{ + "friendly_name": "john", + }, + ExpectedCode: http.StatusOK, + }, + { + Desc: "Update Factor phone number", + FactorData: map[string]interface{}{ + "phone": "+1976154321", + }, + ExpectedCode: http.StatusOK, + }, + } + + // Initialize factor data + for _, c := range cases { + ts.Run(c.Desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.FactorData)) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/users/%s/factors/%s/", u.ID, f.ID), &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.ExpectedCode, w.Code) + }) + } +} + +func (ts *AdminTestSuite) TestAdminUserCreateValidationErrors() { + cases := []struct { + desc string + params map[string]interface{} + }{ + { + desc: "create user without email and phone", + params: map[string]interface{}{ + "password": "test_password", + }, + }, + { + desc: "create user with password and password hash", + params: map[string]interface{}{ + "email": "test@example.com", + "password": "test_password", + "password_hash": "$2y$10$Tk6yEdmTbb/eQ/haDMaCsuCsmtPVprjHMcij1RqiJdLGPDXnL3L1a", + }, + }, + { + desc: "invalid ban duration", + params: map[string]interface{}{ + "email": "test@example.com", + "ban_duration": "never", + }, + }, + { + desc: "custom id is nil", + params: map[string]interface{}{ + "id": "00000000-0000-0000-0000-000000000000", + "email": "test@example.com", + }, + }, + { + desc: "bad id format", + params: map[string]interface{}{ + "id": "bad_uuid_format", + "email": "test@example.com", + }, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusBadRequest, w.Code, w) + + data := map[string]interface{}{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), data["error_code"], apierrors.ErrorCodeValidationFailed) + }) + + } +} diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go new file mode 100644 index 000000000..b3bc61b92 --- /dev/null +++ b/internal/api/anonymous.go @@ -0,0 +1,56 @@ +package api + +import ( + "net/http" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + db := a.db.WithContext(ctx) + aud := a.requestAud(ctx, r) + + if config.DisableSignup { + return unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") + } + + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + params.Aud = aud + params.Provider = "anonymous" + + newUser, err := params.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + + var grantParams models.GrantParams + grantParams.FillGrantParams(r) + + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + newUser, terr = a.signupNewUser(tx, newUser) + if terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, newUser, models.Anonymous, grantParams) + if terr != nil { + return terr + } + return nil + }) + if err != nil { + return internalServerError("Database error creating anonymous user").WithInternalError(err) + } + + metering.RecordLogin("anonymous", newUser.ID) + return sendJSON(w, http.StatusOK, token) +} diff --git a/internal/api/anonymous_test.go b/internal/api/anonymous_test.go new file mode 100644 index 000000000..81d900de8 --- /dev/null +++ b/internal/api/anonymous_test.go @@ -0,0 +1,329 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" +) + +type AnonymousTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestAnonymous(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &AnonymousTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *AnonymousTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create anonymous user + params := &SignupParams{ + Aud: ts.Config.JWT.Aud, + Provider: "anonymous", + } + u, err := params.ToUserModel(false) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new anonymous test user") +} + +func (ts *AnonymousTestSuite) TestAnonymousLogins() { + ts.Config.External.AnonymousUsers.Enabled = true + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "data": map[string]interface{}{ + "field": "foo", + }, + })) + + req := httptest.NewRequest(http.MethodPost, "/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + assert.NotEmpty(ts.T(), data.User.ID) + assert.Equal(ts.T(), ts.Config.JWT.Aud, data.User.Aud) + assert.Empty(ts.T(), data.User.GetEmail()) + assert.Empty(ts.T(), data.User.GetPhone()) + assert.True(ts.T(), data.User.IsAnonymous) + assert.Equal(ts.T(), models.JSONMap(models.JSONMap{"field": "foo"}), data.User.UserMetaData) +} + +func (ts *AnonymousTestSuite) TestConvertAnonymousUserToPermanent() { + ts.Config.External.AnonymousUsers.Enabled = true + ts.Config.Sms.TestOTP = map[string]string{"1234567890": "000000", "1234560000": "000000"} + // test OTPs still require setting up an sms provider + ts.Config.Sms.Provider = "twilio" + ts.Config.Sms.Twilio.AccountSid = "fake-sid" + ts.Config.Sms.Twilio.AuthToken = "fake-token" + ts.Config.Sms.Twilio.MessageServiceSid = "fake-message-service-sid" + + cases := []struct { + desc string + body map[string]interface{} + verificationType string + }{ + { + desc: "convert anonymous user to permanent user with email", + body: map[string]interface{}{ + "email": "test@example.com", + }, + verificationType: "email_change", + }, + { + desc: "convert anonymous user to permanent user with phone", + body: map[string]interface{}{ + "phone": "1234567890", + }, + verificationType: "phone_change", + }, + { + desc: "convert anonymous user to permanent user with email & password", + body: map[string]interface{}{ + "email": "test2@example.com", + "password": "test-password", + }, + verificationType: "email_change", + }, + { + desc: "convert anonymous user to permanent user with phone & password", + body: map[string]interface{}{ + "phone": "1234560000", + "password": "test-password", + }, + verificationType: "phone_change", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{})) + + req := httptest.NewRequest(http.MethodPost, "/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + signupResponse := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&signupResponse)) + + // Add email to anonymous user + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + req = httptest.NewRequest(http.MethodPut, "/user", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signupResponse.Token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Check if anonymous user is still anonymous + user, err := models.FindUserByID(ts.API.db, signupResponse.User.ID) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user) + require.True(ts.T(), user.IsAnonymous) + + // Check if user has a password set + if c.body["password"] != nil { + require.True(ts.T(), user.HasPassword()) + } + + switch c.verificationType { + case mail.EmailChangeVerification: + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "token_hash": user.EmailChangeTokenNew, + "type": c.verificationType, + })) + case phoneChangeVerification: + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "phone": user.PhoneChange, + "token": "000000", + "type": c.verificationType, + })) + } + + req = httptest.NewRequest(http.MethodPost, "/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + // User is a permanent user and not anonymous anymore + assert.Equal(ts.T(), signupResponse.User.ID, data.User.ID) + assert.Equal(ts.T(), ts.Config.JWT.Aud, data.User.Aud) + assert.False(ts.T(), data.User.IsAnonymous) + + // User should have an identity + assert.Len(ts.T(), data.User.Identities, 1) + + switch c.verificationType { + case mail.EmailChangeVerification: + assert.Equal(ts.T(), c.body["email"], data.User.GetEmail()) + assert.Equal(ts.T(), models.JSONMap(models.JSONMap{"provider": "email", "providers": []interface{}{"email"}}), data.User.AppMetaData) + assert.NotEmpty(ts.T(), data.User.EmailConfirmedAt) + case phoneChangeVerification: + assert.Equal(ts.T(), c.body["phone"], data.User.GetPhone()) + assert.Equal(ts.T(), models.JSONMap(models.JSONMap{"provider": "phone", "providers": []interface{}{"phone"}}), data.User.AppMetaData) + assert.NotEmpty(ts.T(), data.User.PhoneConfirmedAt) + } + }) + } +} + +func (ts *AnonymousTestSuite) TestRateLimitAnonymousSignups() { + var buffer bytes.Buffer + ts.Config.External.AnonymousUsers.Enabled = true + + // It rate limits after 30 requests + for i := 0; i < int(ts.Config.RateLimitAnonymousUsers); i++ { + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{})) + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "1.2.3.4") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + } + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{})) + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "1.2.3.4") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It ignores X-Forwarded-For by default + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{})) + req.Header.Set("X-Forwarded-For", "1.1.1.1") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It doesn't rate limit a new value for the limited header + req.Header.Set("My-Custom-Header", "5.6.7.8") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *AnonymousTestSuite) TestAdminUpdateAnonymousUser() { + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + adminJwt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + + u1, err := models.NewUser("", "", "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + u1.IsAnonymous = true + require.NoError(ts.T(), ts.API.db.Create(u1)) + + u2, err := models.NewUser("", "", "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + u2.IsAnonymous = true + require.NoError(ts.T(), ts.API.db.Create(u2)) + + cases := []struct { + desc string + userId uuid.UUID + body map[string]interface{} + expected map[string]interface{} + expectedIdentities int + }{ + { + desc: "update anonymous user with email and email confirm true", + userId: u1.ID, + body: map[string]interface{}{ + "email": "foo@example.com", + "email_confirm": true, + }, + expected: map[string]interface{}{ + "email": "foo@example.com", + "is_anonymous": false, + }, + expectedIdentities: 1, + }, + { + desc: "update anonymous user with email and email confirm false", + userId: u2.ID, + body: map[string]interface{}{ + "email": "bar@example.com", + "email_confirm": false, + }, + expected: map[string]interface{}{ + "email": "bar@example.com", + "is_anonymous": true, + }, + expectedIdentities: 1, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/users/%s", c.userId), &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", adminJwt)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var data models.User + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.NotNil(ts.T(), data) + require.Len(ts.T(), data.Identities, c.expectedIdentities) + + actual := map[string]interface{}{ + "email": data.GetEmail(), + "is_anonymous": data.IsAnonymous, + } + + require.Equal(ts.T(), c.expected, actual) + }) + } +} diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 000000000..d852b5c65 --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,324 @@ +package api + +import ( + "net/http" + "regexp" + "time" + + "github.com/rs/cors" + "github.com/sebest/xff" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" + "github.com/supabase/hibp" +) + +const ( + audHeaderName = "X-JWT-AUD" + defaultVersion = "unknown version" +) + +var bearerRegexp = regexp.MustCompile(`^(?:B|b)earer (\S+$)`) + +// API is the main REST API +type API struct { + handler http.Handler + db *storage.Connection + config *conf.GlobalConfiguration + version string + + hibpClient *hibp.PwnedClient + + // overrideTime can be used to override the clock used by handlers. Should only be used in tests! + overrideTime func() time.Time + + limiterOpts *LimiterOptions +} + +func (a *API) Version() string { + return a.version +} + +func (a *API) Now() time.Time { + if a.overrideTime != nil { + return a.overrideTime() + } + + return time.Now() +} + +// NewAPI instantiates a new REST API +func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection, opt ...Option) *API { + return NewAPIWithVersion(globalConfig, db, defaultVersion, opt...) +} + +func (a *API) deprecationNotices() { + config := a.config + + log := logrus.WithField("component", "api") + + if config.JWT.AdminGroupName != "" { + log.Warn("DEPRECATION NOTICE: GOTRUE_JWT_ADMIN_GROUP_NAME not supported by Supabase's GoTrue, will be removed soon") + } + + if config.JWT.DefaultGroupName != "" { + log.Warn("DEPRECATION NOTICE: GOTRUE_JWT_DEFAULT_GROUP_NAME not supported by Supabase's GoTrue, will be removed soon") + } +} + +// NewAPIWithVersion creates a new REST API using the specified version +func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API { + api := &API{config: globalConfig, db: db, version: version} + + for _, o := range opt { + o.apply(api) + } + if api.limiterOpts == nil { + api.limiterOpts = NewLimiterOptions(globalConfig) + } + if api.config.Password.HIBP.Enabled { + httpClient := &http.Client{ + // all HIBP API requests should finish quickly to avoid + // unnecessary slowdowns + Timeout: 5 * time.Second, + } + + api.hibpClient = &hibp.PwnedClient{ + UserAgent: api.config.Password.HIBP.UserAgent, + HTTP: httpClient, + } + + if api.config.Password.HIBP.Bloom.Enabled { + cache := utilities.NewHIBPBloomCache(api.config.Password.HIBP.Bloom.Items, api.config.Password.HIBP.Bloom.FalsePositives) + api.hibpClient.Cache = cache + + logrus.Infof("Pwned passwords cache is %.2f KB", float64(cache.Cap())/(8*1024.0)) + } + } + + api.deprecationNotices() + + xffmw, _ := xff.Default() + logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig) + + r := newRouter() + r.UseBypass(observability.AddRequestID(globalConfig)) + r.UseBypass(logger) + r.UseBypass(xffmw.Handler) + r.UseBypass(recoverer) + + if globalConfig.API.MaxRequestDuration > 0 { + r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration)) + } + + // request tracing should be added only when tracing or metrics is enabled + if globalConfig.Tracing.Enabled || globalConfig.Metrics.Enabled { + r.UseBypass(observability.RequestTracing()) + } + + if globalConfig.DB.CleanupEnabled { + cleanup := models.NewCleanup(globalConfig) + r.UseBypass(api.databaseCleanup(cleanup)) + } + + r.Get("/health", api.HealthCheck) + r.Get("/.well-known/jwks.json", api.Jwks) + + r.Route("/callback", func(r *router) { + r.Use(api.isValidExternalHost) + r.Use(api.loadFlowState) + + r.Get("/", api.ExternalProviderCallback) + r.Post("/", api.ExternalProviderCallback) + }) + + r.Route("/", func(r *router) { + + r.Use(api.isValidExternalHost) + + r.Get("/settings", api.Settings) + + r.Get("/authorize", api.ExternalProviderRedirect) + + r.With(api.requireAdminCredentials).Post("/invite", api.Invite) + r.With(api.verifyCaptcha).Route("/signup", func(r *router) { + // rate limit per hour + limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns + limitSignups := api.limiterOpts.Signups + r.Post("/", func(w http.ResponseWriter, r *http.Request) error { + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if params.Email == "" && params.Phone == "" { + if !api.config.External.AnonymousUsers.Enabled { + return unprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled") + } + if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil { + return err + } + return api.SignupAnonymously(w, r) + } + + // apply ip-based rate limiting on otps + if _, err := api.limitHandler(limitSignups)(w, r); err != nil { + return err + } + return api.Signup(w, r) + }) + }) + r.With(api.limitHandler(api.limiterOpts.Recover)). + With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) + + r.With(api.limitHandler(api.limiterOpts.Resend)). + With(api.verifyCaptcha).Post("/resend", api.Resend) + + r.With(api.limitHandler(api.limiterOpts.MagicLink)). + With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) + + r.With(api.limitHandler(api.limiterOpts.Otp)). + With(api.verifyCaptcha).Post("/otp", api.Otp) + + r.With(api.limitHandler(api.limiterOpts.Token)). + With(api.verifyCaptcha).Post("/token", api.Token) + + r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) { + r.Get("/", api.Verify) + r.Post("/", api.Verify) + }) + + r.With(api.requireAuthentication).Post("/logout", api.Logout) + + r.With(api.requireAuthentication).Route("/reauthenticate", func(r *router) { + r.Get("/", api.Reauthenticate) + }) + + r.With(api.requireAuthentication).Route("/user", func(r *router) { + r.Get("/", api.UserGet) + r.With(api.limitHandler(api.limiterOpts.User)).Put("/", api.UserUpdate) + + r.Route("/identities", func(r *router) { + r.Use(api.requireManualLinkingEnabled) + r.Get("/authorize", api.LinkIdentity) + r.Delete("/{identity_id}", api.DeleteIdentity) + }) + }) + + r.With(api.requireAuthentication).Route("/factors", func(r *router) { + r.Use(api.requireNotAnonymous) + r.Post("/", api.EnrollFactor) + r.Route("/{factor_id}", func(r *router) { + r.Use(api.loadFactor) + + r.With(api.limitHandler(api.limiterOpts.FactorVerify)). + Post("/verify", api.VerifyFactor) + r.With(api.limitHandler(api.limiterOpts.FactorChallenge)). + Post("/challenge", api.ChallengeFactor) + r.Delete("/", api.UnenrollFactor) + + }) + }) + + r.Route("/sso", func(r *router) { + r.Use(api.requireSAMLEnabled) + r.With(api.limitHandler(api.limiterOpts.SSO)). + With(api.verifyCaptcha).Post("/", api.SingleSignOn) + + r.Route("/saml", func(r *router) { + r.Get("/metadata", api.SAMLMetadata) + + r.With(api.limitHandler(api.limiterOpts.SAMLAssertion)). + Post("/acs", api.SamlAcs) + }) + }) + + r.Route("/admin", func(r *router) { + r.Use(api.requireAdminCredentials) + + r.Route("/audit", func(r *router) { + r.Get("/", api.adminAuditLog) + }) + + r.Route("/users", func(r *router) { + r.Get("/", api.adminUsers) + r.Post("/", api.adminUserCreate) + + r.Route("/{user_id}", func(r *router) { + r.Use(api.loadUser) + r.Route("/factors", func(r *router) { + r.Get("/", api.adminUserGetFactors) + r.Route("/{factor_id}", func(r *router) { + r.Use(api.loadFactor) + r.Delete("/", api.adminUserDeleteFactor) + r.Put("/", api.adminUserUpdateFactor) + }) + }) + + r.Get("/", api.adminUserGet) + r.Put("/", api.adminUserUpdate) + r.Delete("/", api.adminUserDelete) + }) + }) + + r.Post("/generate_link", api.adminGenerateLink) + + r.Route("/sso", func(r *router) { + r.Route("/providers", func(r *router) { + r.Get("/", api.adminSSOProvidersList) + r.Post("/", api.adminSSOProvidersCreate) + + r.Route("/{idp_id}", func(r *router) { + r.Use(api.loadSSOProvider) + + r.Get("/", api.adminSSOProvidersGet) + r.Put("/", api.adminSSOProvidersUpdate) + r.Delete("/", api.adminSSOProvidersDelete) + }) + }) + }) + + }) + }) + + corsHandler := cors.New(cors.Options{ + AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}, + AllowedHeaders: globalConfig.CORS.AllAllowedHeaders([]string{"Accept", "Authorization", "Content-Type", "X-Client-IP", "X-Client-Info", audHeaderName, useCookieHeader, APIVersionHeaderName}), + ExposedHeaders: []string{"X-Total-Count", "Link", APIVersionHeaderName}, + AllowCredentials: true, + }) + + api.handler = corsHandler.Handler(r) + return api +} + +type HealthCheckResponse struct { + Version string `json:"version"` + Name string `json:"name"` + Description string `json:"description"` +} + +// HealthCheck endpoint indicates if the gotrue api service is available +func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error { + return sendJSON(w, http.StatusOK, HealthCheckResponse{ + Version: a.version, + Name: "GoTrue", + Description: "GoTrue is a user registration and authentication API", + }) +} + +// Mailer returns NewMailer with the current tenant config +func (a *API) Mailer() mailer.Mailer { + config := a.config + return mailer.NewMailer(config) +} + +// ServeHTTP implements the http.Handler interface by passing the request along +// to its underlying Handler. +func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { + a.handler.ServeHTTP(w, r) +} diff --git a/internal/api/api_test.go b/internal/api/api_test.go new file mode 100644 index 000000000..a472be737 --- /dev/null +++ b/internal/api/api_test.go @@ -0,0 +1,57 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +const ( + apiTestVersion = "1" + apiTestConfig = "../../hack/test.env" +) + +func init() { + crypto.PasswordHashCost = crypto.QuickHashCost +} + +// setupAPIForTest creates a new API to run tests with. +// Using this function allows us to keep track of the database connection +// and cleaning up data between tests. +func setupAPIForTest() (*API, *conf.GlobalConfiguration, error) { + return setupAPIForTestWithCallback(nil) +} + +func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *storage.Connection)) (*API, *conf.GlobalConfiguration, error) { + config, err := conf.LoadGlobal(apiTestConfig) + if err != nil { + return nil, nil, err + } + + if cb != nil { + cb(config, nil) + } + + conn, err := test.SetupDBConnection(config) + if err != nil { + return nil, nil, err + } + + if cb != nil { + cb(nil, conn) + } + + limiterOpts := NewLimiterOptions(config) + return NewAPIWithVersion(config, conn, apiTestVersion, limiterOpts), config, nil +} + +func TestEmailEnabledByDefault(t *testing.T) { + api, _, err := setupAPIForTest() + require.NoError(t, err) + + require.True(t, api.config.External.Email.Enabled) +} diff --git a/internal/api/apierrors/apierrors.go b/internal/api/apierrors/apierrors.go new file mode 100644 index 000000000..7c6780d30 --- /dev/null +++ b/internal/api/apierrors/apierrors.go @@ -0,0 +1,93 @@ +package apierrors + +import ( + "fmt" +) + +// OAuthError is the JSON handler for OAuth2 error responses +type OAuthError struct { + Err string `json:"error"` + Description string `json:"error_description,omitempty"` + InternalError error `json:"-"` + InternalMessage string `json:"-"` +} + +func NewOAuthError(err string, description string) *OAuthError { + return &OAuthError{Err: err, Description: description} +} + +func (e *OAuthError) Error() string { + if e.InternalMessage != "" { + return e.InternalMessage + } + return fmt.Sprintf("%s: %s", e.Err, e.Description) +} + +// WithInternalError adds internal error information to the error +func (e *OAuthError) WithInternalError(err error) *OAuthError { + e.InternalError = err + return e +} + +// WithInternalMessage adds internal message information to the error +func (e *OAuthError) WithInternalMessage(fmtString string, args ...interface{}) *OAuthError { + e.InternalMessage = fmt.Sprintf(fmtString, args...) + return e +} + +// Cause returns the root cause error +func (e *OAuthError) Cause() error { + if e.InternalError != nil { + return e.InternalError + } + return e +} + +// HTTPError is an error with a message and an HTTP status code. +type HTTPError struct { + HTTPStatus int `json:"code"` // do not rename the JSON tags! + ErrorCode string `json:"error_code,omitempty"` // do not rename the JSON tags! + Message string `json:"msg"` // do not rename the JSON tags! + InternalError error `json:"-"` + InternalMessage string `json:"-"` + ErrorID string `json:"error_id,omitempty"` +} + +func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return &HTTPError{ + HTTPStatus: httpStatus, + ErrorCode: errorCode, + Message: fmt.Sprintf(fmtString, args...), + } +} + +func (e *HTTPError) Error() string { + if e.InternalMessage != "" { + return e.InternalMessage + } + return fmt.Sprintf("%d: %s", e.HTTPStatus, e.Message) +} + +func (e *HTTPError) Is(target error) bool { + return e.Error() == target.Error() +} + +// Cause returns the root cause error +func (e *HTTPError) Cause() error { + if e.InternalError != nil { + return e.InternalError + } + return e +} + +// WithInternalError adds internal error information to the error +func (e *HTTPError) WithInternalError(err error) *HTTPError { + e.InternalError = err + return e +} + +// WithInternalMessage adds internal message information to the error +func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError { + e.InternalMessage = fmt.Sprintf(fmtString, args...) + return e +} diff --git a/internal/api/apierrors/errorcode.go b/internal/api/apierrors/errorcode.go new file mode 100644 index 000000000..2a151924b --- /dev/null +++ b/internal/api/apierrors/errorcode.go @@ -0,0 +1,97 @@ +package apierrors + +type ErrorCode = string + +const ( + // ErrorCodeUnknown should not be used directly, it only indicates a failure in the error handling system in such a way that an error code was not assigned properly. + ErrorCodeUnknown ErrorCode = "unknown" + + // ErrorCodeUnexpectedFailure signals an unexpected failure such as a 500 Internal Server Error. + ErrorCodeUnexpectedFailure ErrorCode = "unexpected_failure" + + ErrorCodeValidationFailed ErrorCode = "validation_failed" + ErrorCodeBadJSON ErrorCode = "bad_json" + ErrorCodeEmailExists ErrorCode = "email_exists" + ErrorCodePhoneExists ErrorCode = "phone_exists" + ErrorCodeBadJWT ErrorCode = "bad_jwt" + ErrorCodeNotAdmin ErrorCode = "not_admin" + ErrorCodeNoAuthorization ErrorCode = "no_authorization" + ErrorCodeUserNotFound ErrorCode = "user_not_found" + ErrorCodeSessionNotFound ErrorCode = "session_not_found" + ErrorCodeSessionExpired ErrorCode = "session_expired" + ErrorCodeRefreshTokenNotFound ErrorCode = "refresh_token_not_found" + ErrorCodeRefreshTokenAlreadyUsed ErrorCode = "refresh_token_already_used" + ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found" + ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired" + ErrorCodeSignupDisabled ErrorCode = "signup_disabled" + ErrorCodeUserBanned ErrorCode = "user_banned" + ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification" + ErrorCodeInviteNotFound ErrorCode = "invite_not_found" + ErrorCodeBadOAuthState ErrorCode = "bad_oauth_state" + ErrorCodeBadOAuthCallback ErrorCode = "bad_oauth_callback" + ErrorCodeOAuthProviderNotSupported ErrorCode = "oauth_provider_not_supported" + ErrorCodeUnexpectedAudience ErrorCode = "unexpected_audience" + ErrorCodeSingleIdentityNotDeletable ErrorCode = "single_identity_not_deletable" + ErrorCodeEmailConflictIdentityNotDeletable ErrorCode = "email_conflict_identity_not_deletable" + ErrorCodeIdentityAlreadyExists ErrorCode = "identity_already_exists" + ErrorCodeEmailProviderDisabled ErrorCode = "email_provider_disabled" + ErrorCodePhoneProviderDisabled ErrorCode = "phone_provider_disabled" + ErrorCodeTooManyEnrolledMFAFactors ErrorCode = "too_many_enrolled_mfa_factors" + ErrorCodeMFAFactorNameConflict ErrorCode = "mfa_factor_name_conflict" + ErrorCodeMFAFactorNotFound ErrorCode = "mfa_factor_not_found" + ErrorCodeMFAIPAddressMismatch ErrorCode = "mfa_ip_address_mismatch" + ErrorCodeMFAChallengeExpired ErrorCode = "mfa_challenge_expired" + ErrorCodeMFAVerificationFailed ErrorCode = "mfa_verification_failed" + ErrorCodeMFAVerificationRejected ErrorCode = "mfa_verification_rejected" + ErrorCodeInsufficientAAL ErrorCode = "insufficient_aal" + ErrorCodeCaptchaFailed ErrorCode = "captcha_failed" + ErrorCodeSAMLProviderDisabled ErrorCode = "saml_provider_disabled" + ErrorCodeManualLinkingDisabled ErrorCode = "manual_linking_disabled" + ErrorCodeSMSSendFailed ErrorCode = "sms_send_failed" + ErrorCodeEmailNotConfirmed ErrorCode = "email_not_confirmed" + ErrorCodePhoneNotConfirmed ErrorCode = "phone_not_confirmed" + ErrorCodeSAMLRelayStateNotFound ErrorCode = "saml_relay_state_not_found" + ErrorCodeSAMLRelayStateExpired ErrorCode = "saml_relay_state_expired" + ErrorCodeSAMLIdPNotFound ErrorCode = "saml_idp_not_found" + ErrorCodeSAMLAssertionNoUserID ErrorCode = "saml_assertion_no_user_id" + ErrorCodeSAMLAssertionNoEmail ErrorCode = "saml_assertion_no_email" + ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists" + ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found" + ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed" + ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists" + ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists" + ErrorCodeSAMLEntityIDMismatch ErrorCode = "saml_entity_id_mismatch" + ErrorCodeConflict ErrorCode = "conflict" + ErrorCodeProviderDisabled ErrorCode = "provider_disabled" + ErrorCodeUserSSOManaged ErrorCode = "user_sso_managed" + ErrorCodeReauthenticationNeeded ErrorCode = "reauthentication_needed" + ErrorCodeSamePassword ErrorCode = "same_password" + ErrorCodeReauthenticationNotValid ErrorCode = "reauthentication_not_valid" + ErrorCodeOTPExpired ErrorCode = "otp_expired" + ErrorCodeOTPDisabled ErrorCode = "otp_disabled" + ErrorCodeIdentityNotFound ErrorCode = "identity_not_found" + ErrorCodeWeakPassword ErrorCode = "weak_password" + ErrorCodeOverRequestRateLimit ErrorCode = "over_request_rate_limit" + ErrorCodeOverEmailSendRateLimit ErrorCode = "over_email_send_rate_limit" + ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit" + ErrorCodeBadCodeVerifier ErrorCode = "bad_code_verifier" + ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" + ErrorCodeHookTimeout ErrorCode = "hook_timeout" + ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry" + ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit" + ErrorCodeHookPayloadInvalidContentType ErrorCode = "hook_payload_invalid_content_type" + ErrorCodeRequestTimeout ErrorCode = "request_timeout" + ErrorCodeMFAPhoneEnrollDisabled ErrorCode = "mfa_phone_enroll_not_enabled" + ErrorCodeMFAPhoneVerifyDisabled ErrorCode = "mfa_phone_verify_not_enabled" + ErrorCodeMFATOTPEnrollDisabled ErrorCode = "mfa_totp_enroll_not_enabled" + ErrorCodeMFATOTPVerifyDisabled ErrorCode = "mfa_totp_verify_not_enabled" + ErrorCodeMFAWebAuthnEnrollDisabled ErrorCode = "mfa_webauthn_enroll_not_enabled" + ErrorCodeMFAWebAuthnVerifyDisabled ErrorCode = "mfa_webauthn_verify_not_enabled" + ErrorCodeMFAVerifiedFactorExists ErrorCode = "mfa_verified_factor_exists" + //#nosec G101 -- Not a secret value. + ErrorCodeInvalidCredentials ErrorCode = "invalid_credentials" + ErrorCodeEmailAddressNotAuthorized ErrorCode = "email_address_not_authorized" + ErrorCodeEmailAddressInvalid ErrorCode = "email_address_invalid" + ErrorCodeWeb3ProviderDisabled ErrorCode = "web3_provider_disabled" + ErrorCodeWeb3UnsupportedChain ErrorCode = "web3_unsupported_chain" +) diff --git a/internal/api/apiversions.go b/internal/api/apiversions.go new file mode 100644 index 000000000..b5394a5fc --- /dev/null +++ b/internal/api/apiversions.go @@ -0,0 +1,35 @@ +package api + +import ( + "time" +) + +const APIVersionHeaderName = "X-Supabase-Api-Version" + +type APIVersion = time.Time + +var ( + APIVersionInitial = time.Time{} + APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) +) + +func DetermineClosestAPIVersion(date string) (APIVersion, error) { + if date == "" { + return APIVersionInitial, nil + } + + parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC) + if err != nil { + return APIVersionInitial, err + } + + if parsed.Compare(APIVersion20240101) >= 0 { + return APIVersion20240101, nil + } + + return APIVersionInitial, nil +} + +func FormatAPIVersion(apiVersion APIVersion) string { + return apiVersion.Format("2006-01-02") +} diff --git a/internal/api/apiversions_test.go b/internal/api/apiversions_test.go new file mode 100644 index 000000000..0a9622132 --- /dev/null +++ b/internal/api/apiversions_test.go @@ -0,0 +1,29 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDetermineClosestAPIVersion(t *testing.T) { + version, err := DetermineClosestAPIVersion("") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("Not a date") + require.Error(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2023-12-31") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2024-01-01") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) + + version, err = DetermineClosestAPIVersion("2024-01-02") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) +} diff --git a/api/audit.go b/internal/api/audit.go similarity index 59% rename from api/audit.go rename to internal/api/audit.go index ea1236798..e2d71bbb4 100644 --- a/api/audit.go +++ b/internal/api/audit.go @@ -4,23 +4,24 @@ import ( "net/http" "strings" - "github.com/netlify/gotrue/models" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" ) var filterColumnMap = map[string][]string{ - "author": []string{"actor_username", "actor_name"}, - "action": []string{"action"}, - "type": []string{"log_type"}, + "author": {"actor_username", "actor_name"}, + "action": {"action"}, + "type": {"log_type"}, } func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() - instanceID := getInstanceID(ctx) - // aud := a.requestAud(ctx, r) + db := a.db.WithContext(ctx) + // aud := a.requestAud(ctx, r) pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) } var col []string @@ -31,12 +32,12 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { qparts := strings.SplitN(q, ":", 2) col, exists = filterColumnMap[qparts[0]] if !exists || len(qparts) < 2 { - return badRequestError("Invalid query scope: %s", q) + return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid query scope: %s", q) } qval = qparts[1] } - logs, err := models.FindAuditLogEntries(a.db, instanceID, col, qval, pageParams) + logs, err := models.FindAuditLogEntries(db, col, qval, pageParams) if err != nil { return internalServerError("Error searching for audit logs").WithInternalError(err) } diff --git a/api/audit_test.go b/internal/api/audit_test.go similarity index 76% rename from api/audit_test.go rename to internal/api/audit_test.go index b6a3fdd08..c8e992ead 100644 --- a/api/audit_test.go +++ b/internal/api/audit_test.go @@ -6,34 +6,30 @@ import ( "net/http" "net/http/httptest" "testing" - "time" - "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" + jwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" ) type AuditTestSuite struct { suite.Suite API *API - Config *conf.Configuration + Config *conf.GlobalConfiguration - token string - instanceID uuid.UUID + token string } func TestAudit(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() + api, config, err := setupAPIForTest() require.NoError(t, err) ts := &AuditTestSuite{ - API: api, - Config: config, - instanceID: instanceID, + API: api, + Config: config, } defer api.db.Close() @@ -46,15 +42,23 @@ func (ts *AuditTestSuite) SetupTest() { } func (ts *AuditTestSuite) makeSuperAdmin(email string) string { - u, err := models.NewUser(ts.instanceID, email, "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) + u, err := models.NewUser("", email, "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) require.NoError(ts.T(), err, "Error making new user") u.Role = "supabase_admin" + require.NoError(ts.T(), ts.API.db.Create(u)) - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + var token string + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.PasswordGrant) require.NoError(ts.T(), err, "Error generating access token") - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.Parse(token, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) @@ -74,8 +78,8 @@ func (ts *AuditTestSuite) TestAuditGet() { ts.API.handler.ServeHTTP(w, req) require.Equal(ts.T(), http.StatusOK, w.Code) - assert.Equal(ts.T(), "; rel=\"last\"", w.HeaderMap.Get("Link")) - assert.Equal(ts.T(), "1", w.HeaderMap.Get("X-Total-Count")) + assert.Equal(ts.T(), "; rel=\"last\"", w.Header().Get("Link")) + assert.Equal(ts.T(), "1", w.Header().Get("X-Total-Count")) logs := []models.AuditLogEntry{} require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&logs)) @@ -121,7 +125,7 @@ func (ts *AuditTestSuite) TestAuditFilters() { func (ts *AuditTestSuite) prepareDeleteEvent() { // DELETE USER - u, err := models.NewUser(ts.instanceID, "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) + u, err := models.NewUser("12345678", "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) require.NoError(ts.T(), err, "Error making new user") require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") diff --git a/internal/api/auth.go b/internal/api/auth.go new file mode 100644 index 000000000..be2400d2f --- /dev/null +++ b/internal/api/auth.go @@ -0,0 +1,142 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// requireAuthentication checks incoming requests for tokens presented using the Authorization header +func (a *API) requireAuthentication(w http.ResponseWriter, r *http.Request) (context.Context, error) { + token, err := a.extractBearerToken(r) + if err != nil { + return nil, err + } + + ctx, err := a.parseJWTClaims(token, r) + if err != nil { + return ctx, err + } + + ctx, err = a.maybeLoadUserOrSession(ctx) + if err != nil { + return ctx, err + } + return ctx, err +} + +func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + claims := getClaims(ctx) + if claims.IsAnonymous { + return nil, forbiddenError(apierrors.ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions") + } + return ctx, nil +} + +func (a *API) requireAdmin(ctx context.Context) (context.Context, error) { + // Find the administrative user + claims := getClaims(ctx) + if claims == nil { + return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "Invalid token") + } + + adminRoles := a.config.JWT.AdminRoles + + if isStringInSlice(claims.Role, adminRoles) { + // successful authentication + return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil + } + + return nil, forbiddenError(apierrors.ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", "))) +} + +func (a *API) extractBearerToken(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + matches := bearerRegexp.FindStringSubmatch(authHeader) + if len(matches) != 2 { + return "", httpError(http.StatusUnauthorized, apierrors.ErrorCodeNoAuthorization, "This endpoint requires a Bearer token") + } + + return matches[1], nil +} + +func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, error) { + ctx := r.Context() + config := a.config + + p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods)) + token, err := p.ParseWithClaims(bearer, &AccessTokenClaims{}, func(token *jwt.Token) (interface{}, error) { + if kid, ok := token.Header["kid"]; ok { + if kidStr, ok := kid.(string); ok { + return conf.FindPublicKeyByKid(kidStr, &config.JWT) + } + } + if alg, ok := token.Header["alg"]; ok { + if alg == jwt.SigningMethodHS256.Name { + // preserve backward compatibility for cases where the kid is not set + return []byte(config.JWT.Secret), nil + } + } + return nil, fmt.Errorf("missing kid") + }) + if err != nil { + return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) + } + + return withToken(ctx, token), nil +} + +func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, error) { + db := a.db.WithContext(ctx) + claims := getClaims(ctx) + + if claims == nil { + return ctx, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid token: missing claims") + } + + if claims.Subject == "" { + return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim") + } + + var user *models.User + if claims.Subject != "" { + userId, err := uuid.FromString(claims.Subject) + if err != nil { + return ctx, badRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err) + } + user, err = models.FindUserByID(db, userId) + if err != nil { + if models.IsNotFoundError(err) { + return ctx, forbiddenError(apierrors.ErrorCodeUserNotFound, "User from sub claim in JWT does not exist") + } + return ctx, err + } + ctx = withUser(ctx, user) + } + + var session *models.Session + if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() { + sessionId, err := uuid.FromString(claims.SessionId) + if err != nil { + return ctx, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err) + } + session, err = models.FindSessionByID(db, sessionId, false) + if err != nil { + if models.IsNotFoundError(err) { + return ctx, forbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId)) + } + return ctx, err + } + ctx = withSession(ctx, session) + } + return ctx, nil +} diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go new file mode 100644 index 000000000..4d75184da --- /dev/null +++ b/internal/api/auth_test.go @@ -0,0 +1,285 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + jwk "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type AuthTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestAuth(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &AuthTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + suite.Run(t, ts) +} + +func (ts *AuthTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") +} + +func (ts *AuthTestSuite) TestExtractBearerToken() { + userClaims := &AccessTokenClaims{ + Role: "authenticated", + } + userJwt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, userClaims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("Authorization", "Bearer "+userJwt) + + token, err := ts.API.extractBearerToken(req) + require.NoError(ts.T(), err) + require.Equal(ts.T(), userJwt, token) +} + +func (ts *AuthTestSuite) TestParseJWTClaims() { + cases := []struct { + desc string + key map[string]interface{} + }{ + { + desc: "HMAC key", + key: map[string]interface{}{ + "kty": "oct", + "k": "S1LgKUjeqXDEolv9WPtjUpADVMHU_KYu8uRDrM-pDGg", + "kid": "ac50c3cc-9cf7-4fd6-a11f-fe066fd39118", + "key_ops": []string{"sign", "verify"}, + "alg": "HS256", + }, + }, + { + desc: "RSA key", + key: map[string]interface{}{ + "kty": "RSA", + "n": "2g0B_hMIx5ZPuTUtLRpRr0k314XniYm3AUFgR5FmTZIjrn7vLwsWij-2egGZeHa-y9ypAgB9Q-lQ3AlT7RMPiCIyLQI6TTC8k10NEnj8c0QZwENx1Qr8aBbuZbOP9Cz30EMWZSbzMbz7r8-3rp5wBRBtIPnLlbfZh_p0iBaJfB77-r_mvhOIFM4xS7ef3nkE96dnvbEN5a-HfjzDJIAt-LniUvzMWW2gQcmHiM4oeijE3PHesapLMt2JpsMhSRo8L7tysags9VMoyZ1GnpCdjtRwb_KpY9QTjV6lL8G5nsKFH7bhABYcpjDOvqkfT5nPXj6C7oCo6MPRirPWUTbq2w", + "e": "AQAB", + "d": "OOTj_DNjOxCRRLYHT5lqbt4f3_BkdZKlWYKBaKsbkmnrPYCJUDEIdJIjPrpkHPZ-2hp9TrRp-upJ2t_kMhujFdY2WWAXbkSlL5475vICjODcBzqR3RC8wzwYgBjWGtQQ5RpcIZCELBovYbRFLR7SA8BBeTU0VaBe9gf3l_qpbOT9QIl268uFdWndTjpehGLQRmAtR1snhvTha0b9nsBZsM_K-EfnoF7Q_lPsjwWDvIGpFXao8Ifaa_sFtQkHjHVBMW2Qgx3ZSrEva_brk7w0MNSYI7Nsmr56xFOpFRwZy0v8ZtgQZ4hXmUInRHIoQ2APeds9YmemojvJKVflt9pLIQ", + "p": "-o2hdQ5Z35cIS5APTVULj_BMoPJpgkuX-PSYC1SeBeff9K04kG5zrFMWJy_-27-ys4q754lpNwJdX2CjN1nb6qyn-uKP8B2oLayKs9ebkiOqvm3S2Xblvi_F8x6sOLba3lTYHK8G7U9aMB9U0mhAzzMFdw15XXusVFDvk-zxL28", + "q": "3sp-7HzZE_elKRmebjivcDhkXO2GrcN3EIqYbbXssHZFXJwVE9oc2CErGWa7QetOCr9C--ZuTmX0X3L--CoYr-hMB0dN8lcAhapr3aau-4i7vE3DWSUdcFSyi0BBDg8pWQWbxNyTXBuWeh1cnRBsLjCxAOVTF0y3_BnVR7mbBVU", + "dp": "DuYHGMfOrk3zz1J0pnuNIXT_iX6AqZ_HHKWmuN3CO8Wq-oimWWhH9pJGOfRPqk9-19BDFiSEniHE3ZwIeI0eV5kGsBNyzatlybl90e3bMVhvmb08EXRRevqqQaesQ_8Tiq7u3t3Fgqz6RuxGBfDvEaMOCyNA-T8WYzkg1eH8AX8", + "dq": "opOCK3CvuDJvA57-TdBvtaRxGJ78OLD6oceBlA29useTthDwEJyJj-4kVVTyMRhUyuLnLoro06zytvRjuxR9D2CkmmseJkn2x5OlQwnvhv4wgSj99H9xDBfCcntg_bFyqtO859tObVh0ZogmnTbuuoYtpEm0aLxDRmRTjxOSXEE", + "qi": "8skVE7BDASHXytKSWYbkxD0B3WpXic2rtnLgiMgasdSxul8XwcB-vjVSZprVrxkcmm6ZhszoxOlq8yylBmMvAnG_gEzTls_xapeuEXGYiGaTcpkCt1r-tBKcQkka2SayaWwAljsX4xSw-zKP2koUkEET_tIcbBOW1R4OWfRGqOI", + "kid": "0d24b26c-b3ec-4c02-acfd-d5a54d50b3a4", + "key_ops": []string{"sign", "verify"}, + "alg": "RS256", + }, + }, + { + desc: "EC key", + key: map[string]interface{}{ + "kty": "EC", + "x": "5wsOh-DrNPpm9KkuydtgGs_cv3oNvtR9OdXywt12aS4", + "y": "0y01ZbuH_VQjMEd8fcYaLdiv25EVJ5GOrb79dJJsqrM", + "crv": "P-256", + "d": "EDP4ReMMpAUcf82EF3JYvkm8C5hVAh258Rj6f3HTx7c", + "kid": "10646a77-f470-44a8-8400-2f988d9c9c1a", + "key_ops": []string{"sign", "verify"}, + "alg": "ES256", + }, + }, + { + desc: "Ed25519 key", + key: map[string]interface{}{ + "crv": "Ed25519", + "d": "jVpCLvOxatVkKe1MW9nFRn6Q8VVZPq5yziKU_Z0Yu-c", + "x": "YDkGdufJBQEPO6ylvd9IKfZlzvm9tOG5VCDpkJSSkiA", + "kty": "OKP", + "kid": "ec5e7a96-ea66-456c-826c-d8d6cb928c0f", + "key_ops": []string{"sign", "verify"}, + "alg": "EdDSA", + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + bytes, err := json.Marshal(c.key) + require.NoError(ts.T(), err) + privKey, err := jwk.ParseKey(bytes) + require.NoError(ts.T(), err) + pubKey, err := privKey.PublicKey() + require.NoError(ts.T(), err) + ts.Config.JWT.Keys = conf.JwtKeysDecoder{privKey.KeyID(): conf.JwkInfo{ + PublicKey: pubKey, + PrivateKey: privKey, + }} + ts.Config.JWT.ValidMethods = nil + require.NoError(ts.T(), ts.Config.ApplyDefaults()) + + userClaims := &AccessTokenClaims{ + Role: "authenticated", + } + + // get signing key and method from config + jwk, err := conf.GetSigningJwk(&ts.Config.JWT) + require.NoError(ts.T(), err) + signingMethod := conf.GetSigningAlg(jwk) + signingKey, err := conf.GetSigningKey(jwk) + require.NoError(ts.T(), err) + + userJwtToken := jwt.NewWithClaims(signingMethod, userClaims) + require.NoError(ts.T(), err) + userJwtToken.Header["kid"] = jwk.KeyID() + userJwt, err := userJwtToken.SignedString(signingKey) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("Authorization", "Bearer "+userJwt) + ctx, err := ts.API.parseJWTClaims(userJwt, req) + require.NoError(ts.T(), err) + + // check if token is stored in context + token := getToken(ctx) + require.Equal(ts.T(), userJwt, token.Raw) + }) + } +} + +func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + require.NoError(ts.T(), ts.API.db.Load(s)) + + cases := []struct { + Desc string + UserJwtClaims *AccessTokenClaims + ExpectedError error + ExpectedUser *models.User + ExpectedSession *models.Session + }{ + { + Desc: "Missing Subject Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: "", + }, + Role: "authenticated", + }, + ExpectedError: forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim"), + ExpectedUser: nil, + }, + { + Desc: "Valid Subject Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + }, + ExpectedError: nil, + ExpectedUser: u, + }, + { + Desc: "Invalid Subject Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: "invalid-subject-claim", + }, + Role: "authenticated", + }, + ExpectedError: badRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"), + ExpectedUser: nil, + }, + { + Desc: "Empty Session ID Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + SessionId: "", + }, + ExpectedError: nil, + ExpectedUser: u, + }, + { + Desc: "Invalid Session ID Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + SessionId: uuid.Nil.String(), + }, + ExpectedError: nil, + ExpectedUser: u, + }, + { + Desc: "Valid Session ID Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + SessionId: s.ID.String(), + }, + ExpectedError: nil, + ExpectedUser: u, + ExpectedSession: s, + }, + { + Desc: "Session ID doesn't exist", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + SessionId: "73bf9ee0-9e8c-453b-b484-09cb93e2f341", + }, + ExpectedError: forbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(models.SessionNotFoundError{}).WithInternalMessage("session id (73bf9ee0-9e8c-453b-b484-09cb93e2f341) doesn't exist"), + ExpectedUser: u, + ExpectedSession: nil, + }, + } + + for _, c := range cases { + ts.Run(c.Desc, func() { + userJwt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, c.UserJwtClaims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("Authorization", "Bearer "+userJwt) + + ctx, err := ts.API.parseJWTClaims(userJwt, req) + require.NoError(ts.T(), err) + ctx, err = ts.API.maybeLoadUserOrSession(ctx) + if c.ExpectedError != nil { + require.Equal(ts.T(), c.ExpectedError.Error(), err.Error()) + } else { + require.Equal(ts.T(), c.ExpectedError, err) + } + require.Equal(ts.T(), c.ExpectedUser, getUser(ctx)) + require.Equal(ts.T(), c.ExpectedSession, getSession(ctx)) + }) + } +} diff --git a/api/context.go b/internal/api/context.go similarity index 57% rename from api/context.go rename to internal/api/context.go index f1c4a78c1..3047f3dd6 100644 --- a/api/context.go +++ b/internal/api/context.go @@ -2,11 +2,10 @@ package api import ( "context" + "net/url" - "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/models" ) type contextKey string @@ -17,20 +16,21 @@ func (c contextKey) String() string { const ( tokenKey = contextKey("jwt") - requestIDKey = contextKey("request_id") - configKey = contextKey("config") inviteTokenKey = contextKey("invite_token") - instanceIDKey = contextKey("instance_id") - instanceKey = contextKey("instance") signatureKey = contextKey("signature") - netlifyIDKey = contextKey("netlify_id") externalProviderTypeKey = contextKey("external_provider_type") userKey = contextKey("user") + targetUserKey = contextKey("target_user") + factorKey = contextKey("factor") + sessionKey = contextKey("session") externalReferrerKey = contextKey("external_referrer") functionHooksKey = contextKey("function_hooks") adminUserKey = contextKey("admin_user") oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token oauthVerifierKey = contextKey("oauth_verifier") + ssoProviderKey = contextKey("sso_provider") + externalHostKey = contextKey("external_host") + flowStateKey = contextKey("flow_state_id") ) // withToken adds the JWT token to the context. @@ -48,82 +48,77 @@ func getToken(ctx context.Context) *jwt.Token { return obj.(*jwt.Token) } -func getClaims(ctx context.Context) *GoTrueClaims { +func getClaims(ctx context.Context) *AccessTokenClaims { token := getToken(ctx) if token == nil { return nil } - return token.Claims.(*GoTrueClaims) + return token.Claims.(*AccessTokenClaims) } -// withRequestID adds the provided request ID to the context. -func withRequestID(ctx context.Context, id string) context.Context { - return context.WithValue(ctx, requestIDKey, id) +// withUser adds the user to the context. +func withUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, userKey, u) } -// getRequestID reads the request ID from the context. -func getRequestID(ctx context.Context) string { - obj := ctx.Value(requestIDKey) - if obj == nil { - return "" - } - - return obj.(string) +// withTargetUser adds the target user for linking to the context. +func withTargetUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, targetUserKey, u) } -// withConfig adds the tenant configuration to the context. -func withConfig(ctx context.Context, config *conf.Configuration) context.Context { - return context.WithValue(ctx, configKey, config) +// with Factor adds the factor id to the context. +func withFactor(ctx context.Context, f *models.Factor) context.Context { + return context.WithValue(ctx, factorKey, f) } -func getConfig(ctx context.Context) *conf.Configuration { - obj := ctx.Value(configKey) +// getUser reads the user from the context. +func getUser(ctx context.Context) *models.User { + if ctx == nil { + return nil + } + obj := ctx.Value(userKey) if obj == nil { return nil } - return obj.(*conf.Configuration) -} - -// withInstanceID adds the instance id to the context. -func withInstanceID(ctx context.Context, id uuid.UUID) context.Context { - return context.WithValue(ctx, instanceIDKey, id) + return obj.(*models.User) } -// getInstanceID reads the instance id from the context. -func getInstanceID(ctx context.Context) uuid.UUID { - obj := ctx.Value(instanceIDKey) +// getTargetUser reads the user from the context. +func getTargetUser(ctx context.Context) *models.User { + if ctx == nil { + return nil + } + obj := ctx.Value(targetUserKey) if obj == nil { - return uuid.Nil + return nil } - return obj.(uuid.UUID) -} - -// withInstance adds the instance id to the context. -func withInstance(ctx context.Context, i *models.Instance) context.Context { - return context.WithValue(ctx, instanceKey, i) + return obj.(*models.User) } -// getInstance reads the instance id from the context. -func getInstance(ctx context.Context) *models.Instance { - obj := ctx.Value(instanceKey) +// getFactor reads the factor id from the context +func getFactor(ctx context.Context) *models.Factor { + obj := ctx.Value(factorKey) if obj == nil { return nil } - return obj.(*models.Instance) + return obj.(*models.Factor) } -// withUser adds the user id to the context. -func withUser(ctx context.Context, u *models.User) context.Context { - return context.WithValue(ctx, userKey, u) +// withSession adds the session to the context. +func withSession(ctx context.Context, s *models.Session) context.Context { + return context.WithValue(ctx, sessionKey, s) } -// getUser reads the user id from the context. -func getUser(ctx context.Context) *models.User { - obj := ctx.Value(userKey) +// getSession reads the session from the context. +func getSession(ctx context.Context) *models.Session { + if ctx == nil { + return nil + } + obj := ctx.Value(sessionKey) if obj == nil { return nil } - return obj.(*models.User) + return obj.(*models.Session) } // withSignature adds the provided request ID to the context. @@ -131,35 +126,22 @@ func withSignature(ctx context.Context, id string) context.Context { return context.WithValue(ctx, signatureKey, id) } -// getSignature reads the request ID from the context. -func getSignature(ctx context.Context) string { - obj := ctx.Value(signatureKey) - if obj == nil { - return "" - } - - return obj.(string) +func withInviteToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, inviteTokenKey, token) } -// withNetlifyID adds the provided request ID to the context. -func withNetlifyID(ctx context.Context, id string) context.Context { - return context.WithValue(ctx, netlifyIDKey, id) +func withFlowStateID(ctx context.Context, FlowStateID string) context.Context { + return context.WithValue(ctx, flowStateKey, FlowStateID) } -// getNetlifyID reads the request ID from the context. -func getNetlifyID(ctx context.Context) string { - obj := ctx.Value(netlifyIDKey) +func getFlowStateID(ctx context.Context) string { + obj := ctx.Value(flowStateKey) if obj == nil { return "" } - return obj.(string) } -func withInviteToken(ctx context.Context, token string) context.Context { - return context.WithValue(ctx, inviteTokenKey, token) -} - func getInviteToken(ctx context.Context) string { obj := ctx.Value(inviteTokenKey) if obj == nil { @@ -197,21 +179,6 @@ func getExternalReferrer(ctx context.Context) string { return obj.(string) } -// withFunctionHooks adds the provided function hooks to the context. -func withFunctionHooks(ctx context.Context, hooks map[string][]string) context.Context { - return context.WithValue(ctx, functionHooksKey, hooks) -} - -// getFunctionHooks reads the request ID from the context. -func getFunctionHooks(ctx context.Context) map[string][]string { - obj := ctx.Value(functionHooksKey) - if obj == nil { - return map[string][]string{} - } - - return obj.(map[string][]string) -} - // withAdminUser adds the admin user to the context. func withAdminUser(ctx context.Context, u *models.User) context.Context { return context.WithValue(ctx, adminUserKey, u) @@ -250,3 +217,27 @@ func getOAuthVerifier(ctx context.Context) string { } return obj.(string) } + +func withSSOProvider(ctx context.Context, provider *models.SSOProvider) context.Context { + return context.WithValue(ctx, ssoProviderKey, provider) +} + +func getSSOProvider(ctx context.Context) *models.SSOProvider { + obj := ctx.Value(ssoProviderKey) + if obj == nil { + return nil + } + return obj.(*models.SSOProvider) +} + +func withExternalHost(ctx context.Context, u *url.URL) context.Context { + return context.WithValue(ctx, externalHostKey, u) +} + +func getExternalHost(ctx context.Context) *url.URL { + obj := ctx.Value(externalHostKey) + if obj == nil { + return nil + } + return obj.(*url.URL) +} diff --git a/internal/api/errors.go b/internal/api/errors.go new file mode 100644 index 000000000..2e1606a1a --- /dev/null +++ b/internal/api/errors.go @@ -0,0 +1,257 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "os" + "runtime/debug" + "time" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/utilities" +) + +// Common error messages during signup flow +var ( + DuplicateEmailMsg = "A user with this email address has already been registered" + DuplicatePhoneMsg = "A user with this phone number has already been registered" + UserExistsError error = errors.New("user already exists") +) + +const InvalidChannelError = "Invalid channel, supported values are 'sms' or 'whatsapp'. 'whatsapp' is only supported if Twilio or Twilio Verify is used as the provider." + +var oauthErrorMap = map[int]string{ + http.StatusBadRequest: "invalid_request", + http.StatusUnauthorized: "unauthorized_client", + http.StatusForbidden: "access_denied", + http.StatusInternalServerError: "server_error", + http.StatusServiceUnavailable: "temporarily_unavailable", +} + +// Type aliases while we slowly refactor api errors. +type ( + HTTPError = apierrors.HTTPError + OAuthError = apierrors.OAuthError +) + +func oauthError(err string, description string) *OAuthError { + return apierrors.NewOAuthError(err, description) +} + +func httpError(httpStatus int, errorCode apierrors.ErrorCode, fmtString string, args ...any) *HTTPError { + return apierrors.NewHTTPError(httpStatus, errorCode, fmtString, args...) +} + +func badRequestError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusBadRequest, errorCode, fmtString, args...) +} + +func internalServerError(fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusInternalServerError, apierrors.ErrorCodeUnexpectedFailure, fmtString, args...) +} + +func notFoundError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusNotFound, errorCode, fmtString, args...) +} + +func forbiddenError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusForbidden, errorCode, fmtString, args...) +} + +func unprocessableEntityError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusUnprocessableEntity, errorCode, fmtString, args...) +} + +func tooManyRequestsError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusTooManyRequests, errorCode, fmtString, args...) +} + +func conflictError(fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusConflict, apierrors.ErrorCodeConflict, fmtString, args...) +} + +// Recoverer is a middleware that recovers from panics, logs the panic (and a +// backtrace), and returns a HTTP 500 (Internal Server Error) status if +// possible. Recoverer prints a request ID if one is provided. +func recoverer(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rvr := recover(); rvr != nil { + logEntry := observability.GetLogEntry(r) + if logEntry != nil { + logEntry.Panic(rvr, debug.Stack()) + } else { + fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) + debug.PrintStack() + } + + se := &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), + } + HandleResponseError(se, w, r) + } + }() + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +// ErrorCause is an error interface that contains the method Cause() for returning root cause errors +type ErrorCause interface { + Cause() error +} + +type HTTPErrorResponse20240101 struct { + Code apierrors.ErrorCode `json:"code"` + Message string `json:"message"` +} + +func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(r.Context()) + + apiVersion, averr := DetermineClosestAPIVersion(r.Header.Get(APIVersionHeaderName)) + if averr != nil { + log.WithError(averr).Warn("Invalid version passed to " + APIVersionHeaderName + " header, defaulting to initial version") + } else if apiVersion != APIVersionInitial { + // Echo back the determined API version from the request + w.Header().Set(APIVersionHeaderName, FormatAPIVersion(apiVersion)) + } + + switch e := err.(type) { + case *WeakPasswordError: + if apiVersion.Compare(APIVersion20240101) >= 0 { + var output struct { + HTTPErrorResponse20240101 + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } + + output.Code = apierrors.ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons + + if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + + } else { + var output struct { + HTTPError + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } + + output.HTTPStatus = http.StatusUnprocessableEntity + output.ErrorCode = apierrors.ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons + + w.Header().Set("x-sb-error-code", output.ErrorCode) + + if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } + + case *HTTPError: + switch { + case e.HTTPStatus >= http.StatusInternalServerError: + e.ErrorID = errorID + // this will get us the stack trace too + log.WithError(e.Cause()).Error(e.Error()) + case e.HTTPStatus == http.StatusTooManyRequests: + log.WithError(e.Cause()).Warn(e.Error()) + default: + log.WithError(e.Cause()).Info(e.Error()) + } + + if e.ErrorCode != "" { + w.Header().Set("x-sb-error-code", e.ErrorCode) + } + + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: e.ErrorCode, + Message: e.Message, + } + + if resp.Code == "" { + if e.HTTPStatus == http.StatusInternalServerError { + resp.Code = apierrors.ErrorCodeUnexpectedFailure + } else { + resp.Code = apierrors.ErrorCodeUnknown + } + } + + if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } else { + if e.ErrorCode == "" { + if e.HTTPStatus == http.StatusInternalServerError { + e.ErrorCode = apierrors.ErrorCodeUnexpectedFailure + } else { + e.ErrorCode = apierrors.ErrorCodeUnknown + } + } + + // Provide better error messages for certain user-triggered Postgres errors. + if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { + if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + return + } + + if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } + + case *OAuthError: + log.WithError(e.Cause()).Info(e.Error()) + if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + + case ErrorCause: + HandleResponseError(e.Cause(), w, r) + + default: + log.WithError(e).Errorf("Unhandled server error: %s", e.Error()) + + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: apierrors.ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } else { + httpError := HTTPError{ + HTTPStatus: http.StatusInternalServerError, + ErrorCode: apierrors.ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } + } +} + +func generateFrequencyLimitErrorMessage(timeStamp *time.Time, maxFrequency time.Duration) string { + now := time.Now() + left := timeStamp.Add(maxFrequency).Sub(now) / time.Second + return fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left) +} diff --git a/internal/api/errors_test.go b/internal/api/errors_test.go new file mode 100644 index 000000000..7afdb9ca2 --- /dev/null +++ b/internal/api/errors_test.go @@ -0,0 +1,106 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/observability" +) + +func TestHandleResponseErrorWithHTTPError(t *testing.T) { + examples := []struct { + HTTPError *HTTPError + APIVersion string + ExpectedBody string + }{ + { + HTTPError: badRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + apierrors.ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2023-12-31", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + apierrors.ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + apierrors.ErrorCodeBadJSON + "\",\"message\":\"Unable to parse JSON\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusBadRequest, + Message: "Uncoded failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + apierrors.ErrorCodeUnknown + "\",\"message\":\"Uncoded failure\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: "Unexpected failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + apierrors.ErrorCodeUnexpectedFailure + "\",\"message\":\"Unexpected failure\"}", + }, + } + + for _, example := range examples { + rec := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + if example.APIVersion != "" { + req.Header.Set(APIVersionHeaderName, example.APIVersion) + } + + HandleResponseError(example.HTTPError, rec, req) + + require.Equal(t, example.HTTPError.HTTPStatus, rec.Code) + require.Equal(t, example.ExpectedBody, rec.Body.String()) + } +} + +func TestRecoverer(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + require.NoError(t, observability.ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs are output correctly + logrus.SetOutput(&logBuffer) + panicHandler := recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + })) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + panicHandler.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code) + + var data HTTPError + + // panic should return an internal server error + require.NoError(t, json.NewDecoder(w.Body).Decode(&data)) + require.Equal(t, apierrors.ErrorCodeUnexpectedFailure, data.ErrorCode) + require.Equal(t, http.StatusInternalServerError, data.HTTPStatus) + require.Equal(t, "Internal Server Error", data.Message) + + // panic should log the error message internally + var logs map[string]interface{} + require.NoError(t, json.NewDecoder(&logBuffer).Decode(&logs)) + require.Equal(t, "request panicked", logs["msg"]) + require.Equal(t, "test panic", logs["panic"]) + require.NotEmpty(t, logs["stack"]) +} diff --git a/internal/api/external.go b/internal/api/external.go new file mode 100644 index 000000000..f6dc2384d --- /dev/null +++ b/internal/api/external.go @@ -0,0 +1,692 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/fatih/structs" + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" +) + +// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow +type ExternalProviderClaims struct { + AuthMicroserviceClaims + Provider string `json:"provider"` + InviteToken string `json:"invite_token,omitempty"` + Referrer string `json:"referrer,omitempty"` + FlowStateID string `json:"flow_state_id"` + LinkingTargetID string `json:"linking_target_id,omitempty"` +} + +// ExternalProviderRedirect redirects the request to the oauth provider +func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error { + rurl, err := a.GetExternalProviderRedirectURL(w, r, nil) + if err != nil { + return err + } + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +// GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider +func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request, linkingTargetUser *models.User) (string, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + query := r.URL.Query() + providerType := query.Get("provider") + scopes := query.Get("scopes") + codeChallenge := query.Get("code_challenge") + codeChallengeMethod := query.Get("code_challenge_method") + + p, err := a.Provider(ctx, providerType, scopes) + if err != nil { + return "", badRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err) + } + + inviteToken := query.Get("invite_token") + if inviteToken != "" { + _, userErr := models.FindUserByConfirmationToken(db, inviteToken) + if userErr != nil { + if models.IsNotFoundError(userErr) { + return "", notFoundError(apierrors.ErrorCodeUserNotFound, "User identified by token not found") + } + return "", internalServerError("Database error finding user").WithInternalError(userErr) + } + } + + redirectURL := utilities.GetReferrer(r, config) + log := observability.GetLogEntry(r).Entry + log.WithField("provider", providerType).Info("Redirecting to external provider") + if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil { + return "", err + } + flowType := getFlowFromChallenge(codeChallenge) + + flowStateID := "" + if isPKCEFlow(flowType) { + flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil) + if err != nil { + return "", err + } + flowStateID = flowState.ID.String() + } + + claims := ExternalProviderClaims{ + AuthMicroserviceClaims: AuthMicroserviceClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + SiteURL: config.SiteURL, + InstanceID: uuid.Nil.String(), + }, + Provider: providerType, + InviteToken: inviteToken, + Referrer: redirectURL, + FlowStateID: flowStateID, + } + + if linkingTargetUser != nil { + // this means that the user is performing manual linking + claims.LinkingTargetID = linkingTargetUser.ID.String() + } + + tokenString, err := signJwt(&config.JWT, claims) + if err != nil { + return "", internalServerError("Error creating state").WithInternalError(err) + } + + authUrlParams := make([]oauth2.AuthCodeOption, 0) + query.Del("scopes") + query.Del("provider") + query.Del("code_challenge") + query.Del("code_challenge_method") + for key := range query { + if key == "workos_provider" { + // See https://workos.com/docs/reference/sso/authorize/get + authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key))) + } else { + authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key))) + } + } + + authURL := p.AuthCodeURL(tokenString, authUrlParams...) + + return authURL, nil +} + +// ExternalProviderCallback handles the callback endpoint in the external oauth provider flow +func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) error { + rurl := a.getExternalRedirectURL(r) + u, err := url.Parse(rurl) + if err != nil { + return err + } + redirectErrors(a.internalExternalProviderCallback, w, r, u) + return nil +} + +func (a *API) handleOAuthCallback(r *http.Request) (*OAuthProviderData, error) { + ctx := r.Context() + providerType := getExternalProviderType(ctx) + + var oAuthResponseData *OAuthProviderData + var err error + switch providerType { + case "twitter": + // future OAuth1.0 providers will use this method + oAuthResponseData, err = a.oAuth1Callback(ctx, providerType) + default: + oAuthResponseData, err = a.oAuthCallback(ctx, r, providerType) + } + if err != nil { + return nil, err + } + return oAuthResponseData, nil +} + +func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + var grantParams models.GrantParams + grantParams.FillGrantParams(r) + + providerType := getExternalProviderType(ctx) + data, err := a.handleOAuthCallback(r) + if err != nil { + return err + } + + userData := data.userData + if len(userData.Emails) <= 0 { + return internalServerError("Error getting user email from external provider") + } + userData.Metadata.EmailVerified = false + for _, email := range userData.Emails { + if email.Primary { + userData.Metadata.Email = email.Email + userData.Metadata.EmailVerified = email.Verified + break + } else { + userData.Metadata.Email = email.Email + userData.Metadata.EmailVerified = email.Verified + } + } + providerAccessToken := data.token + providerRefreshToken := data.refreshToken + + var flowState *models.FlowState + // if there's a non-empty FlowStateID we perform PKCE Flow + if flowStateID := getFlowStateID(ctx); flowStateID != "" { + flowState, err = models.FindFlowStateByID(a.db, flowStateID) + if models.IsNotFoundError(err) { + return unprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) + } else if err != nil { + return internalServerError("Failed to find flow state").WithInternalError(err) + } + + } + + var user *models.User + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if targetUser := getTargetUser(ctx); targetUser != nil { + if user, terr = a.linkIdentityToUser(r, ctx, tx, userData, providerType); terr != nil { + return terr + } + } else if inviteToken := getInviteToken(ctx); inviteToken != "" { + if user, terr = a.processInvite(r, tx, userData, inviteToken, providerType); terr != nil { + return terr + } + } else { + if user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType); terr != nil { + return terr + } + } + if flowState != nil { + // This means that the callback is using PKCE + flowState.ProviderAccessToken = providerAccessToken + flowState.ProviderRefreshToken = providerRefreshToken + flowState.UserID = &(user.ID) + issueTime := time.Now() + flowState.AuthCodeIssuedAt = &issueTime + + terr = tx.Update(flowState) + } else { + token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams) + } + + if terr != nil { + return oauthError("server_error", terr.Error()) + } + return nil + }) + + if err != nil { + return err + } + + rurl := a.getExternalRedirectURL(r) + if flowState != nil { + // This means that the callback is using PKCE + // Set the flowState.AuthCode to the query param here + rurl, err = a.prepPKCERedirectURL(rurl, flowState.AuthCode) + if err != nil { + return err + } + } else if token != nil { + q := url.Values{} + q.Set("provider_token", providerAccessToken) + // Because not all providers give out a refresh token + // See corresponding OAuth2 spec: + if providerRefreshToken != "" { + q.Set("provider_refresh_token", providerRefreshToken) + } + + rurl = token.AsRedirectURL(rurl, q) + + } + + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string) (*models.User, error) { + ctx := r.Context() + aud := a.requestAud(ctx, r) + config := a.config + + var user *models.User + var identity *models.Identity + var identityData map[string]interface{} + if userData.Metadata != nil { + identityData = structs.Map(userData.Metadata) + } + + decision, terr := models.DetermineAccountLinking(tx, config, userData.Emails, aud, providerType, userData.Metadata.Subject) + if terr != nil { + return nil, terr + } + + switch decision.Decision { + case models.LinkAccount: + user = decision.User + + if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil { + return nil, terr + } + + if terr = user.UpdateUserMetaData(tx, identityData); terr != nil { + return nil, terr + } + + if terr = user.UpdateAppMetaDataProviders(tx); terr != nil { + return nil, terr + } + + case models.CreateAccount: + if config.DisableSignup { + return nil, unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") + } + + params := &SignupParams{ + Provider: providerType, + Email: decision.CandidateEmail.Email, + Aud: aud, + Data: identityData, + } + + isSSOUser := false + if strings.HasPrefix(decision.LinkingDomain, "sso:") { + isSSOUser = true + } + + // because params above sets no password, this method is not + // computationally hard so it can be used within a database + // transaction + user, terr = params.ToUserModel(isSSOUser) + if terr != nil { + return nil, terr + } + + if user, terr = a.signupNewUser(tx, user); terr != nil { + return nil, terr + } + + if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil { + return nil, terr + } + user.Identities = append(user.Identities, *identity) + case models.AccountExists: + user = decision.User + identity = decision.Identities[0] + + identity.IdentityData = identityData + if terr = tx.UpdateOnly(identity, "identity_data", "last_sign_in_at"); terr != nil { + return nil, terr + } + if terr = user.UpdateUserMetaData(tx, identityData); terr != nil { + return nil, terr + } + if terr = user.UpdateAppMetaDataProviders(tx); terr != nil { + return nil, terr + } + + case models.MultipleAccounts: + return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) + + default: + return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision) + } + + if user.IsBanned() { + return nil, forbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") + } + + // TODO(hf): Expand this boolean with all providers that may not have emails (like X/Twitter, Discord). + hasEmails := providerType != "web3" // intentionally not using len(userData.Emails) != 0 for better backward compatibility control + + if hasEmails && !user.IsConfirmed() { + // The user may have other unconfirmed email + password + // combination, phone or oauth identities. These identities + // need to be removed when a new oauth identity is being added + // to prevent pre-account takeover attacks from happening. + if terr = user.RemoveUnconfirmedIdentities(tx, identity); terr != nil { + return nil, internalServerError("Error updating user").WithInternalError(terr) + } + if decision.CandidateEmail.Verified || config.Mailer.Autoconfirm { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{ + "provider": providerType, + }); terr != nil { + return nil, terr + } + // fall through to auto-confirm and issue token + if terr = user.Confirm(tx); terr != nil { + return nil, internalServerError("Error updating user").WithInternalError(terr) + } + } else { + // Some providers, like web3 don't have email data. + // Treat these as if a confirmation email has been + // sent, although the user will be created without an + // email address. + emailConfirmationSent := false + if decision.CandidateEmail.Email != "" { + if terr = a.sendConfirmation(r, tx, user, models.ImplicitFlow); terr != nil { + return nil, terr + } + emailConfirmationSent = true + } + if !config.Mailer.AllowUnverifiedEmailSignIns { + if emailConfirmationSent { + return nil, storage.NewCommitWithError(unprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + } + return nil, storage.NewCommitWithError(unprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) + } + } + } else { + if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider": providerType, + }); terr != nil { + return nil, terr + } + } + + return user, nil +} + +func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *provider.UserProvidedData, inviteToken, providerType string) (*models.User, error) { + user, err := models.FindUserByConfirmationToken(tx, inviteToken) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(apierrors.ErrorCodeInviteNotFound, "Invite not found") + } + return nil, internalServerError("Database error finding user").WithInternalError(err) + } + + var emailData *provider.Email + var emails []string + for i, e := range userData.Emails { + emails = append(emails, e.Email) + if user.GetEmail() == e.Email { + emailData = &userData.Emails[i] + break + } + } + + if emailData == nil { + return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) + } + + var identityData map[string]interface{} + if userData.Metadata != nil { + identityData = structs.Map(userData.Metadata) + } + identity, err := a.createNewIdentity(tx, user, providerType, identityData) + if err != nil { + return nil, err + } + if err := user.UpdateAppMetaData(tx, map[string]interface{}{ + "provider": providerType, + }); err != nil { + return nil, err + } + if err := user.UpdateAppMetaDataProviders(tx); err != nil { + return nil, err + } + if err := user.UpdateUserMetaData(tx, identityData); err != nil { + return nil, internalServerError("Database error updating user").WithInternalError(err) + } + + if err := models.NewAuditLogEntry(r, tx, user, models.InviteAcceptedAction, "", map[string]interface{}{ + "provider": providerType, + }); err != nil { + return nil, err + } + + // an account with a previously unconfirmed email + password + // combination or phone may exist. so now that there is an + // OAuth identity bound to this user, and since they have not + // confirmed their email or phone, they are unaware that a + // potentially malicious door exists into their account; thus + // the password and phone needs to be removed. + if err := user.RemoveUnconfirmedIdentities(tx, identity); err != nil { + return nil, internalServerError("Error updating user").WithInternalError(err) + } + + // confirm because they were able to respond to invite email + if err := user.Confirm(tx); err != nil { + return nil, err + } + return user, nil +} + +func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) { + var state string + switch r.Method { + case http.MethodPost: + state = r.FormValue("state") + default: + state = r.URL.Query().Get("state") + } + if state == "" { + return ctx, badRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth state parameter missing") + } + config := a.config + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods)) + _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { + if kid, ok := token.Header["kid"]; ok { + if kidStr, ok := kid.(string); ok { + return conf.FindPublicKeyByKid(kidStr, &config.JWT) + } + } + if alg, ok := token.Header["alg"]; ok { + if alg == jwt.SigningMethodHS256.Name { + // preserve backward compatibility for cases where the kid is not set + return []byte(config.JWT.Secret), nil + } + } + return nil, fmt.Errorf("missing kid") + }) + if err != nil { + return ctx, badRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) + } + if claims.Provider == "" { + return ctx, badRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") + } + if claims.InviteToken != "" { + ctx = withInviteToken(ctx, claims.InviteToken) + } + if claims.Referrer != "" { + ctx = withExternalReferrer(ctx, claims.Referrer) + } + if claims.FlowStateID != "" { + ctx = withFlowStateID(ctx, claims.FlowStateID) + } + if claims.LinkingTargetID != "" { + linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) + if err != nil { + return nil, badRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") + } + u, err := models.FindUserByID(a.db, linkingTargetUserID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, unprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found") + } + return nil, internalServerError("Database error loading user").WithInternalError(err) + } + ctx = withTargetUser(ctx, u) + } + ctx = withExternalProviderType(ctx, claims.Provider) + return withSignature(ctx, state), nil +} + +// Provider returns a Provider interface for the given name. +func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, error) { + config := a.config + name = strings.ToLower(name) + + switch name { + case "apple": + return provider.NewAppleProvider(ctx, config.External.Apple) + case "azure": + return provider.NewAzureProvider(config.External.Azure, scopes) + case "bitbucket": + return provider.NewBitbucketProvider(config.External.Bitbucket) + case "discord": + return provider.NewDiscordProvider(config.External.Discord, scopes) + case "facebook": + return provider.NewFacebookProvider(config.External.Facebook, scopes) + case "figma": + return provider.NewFigmaProvider(config.External.Figma, scopes) + case "fly": + return provider.NewFlyProvider(config.External.Fly, scopes) + case "github": + return provider.NewGithubProvider(config.External.Github, scopes) + case "gitlab": + return provider.NewGitlabProvider(config.External.Gitlab, scopes) + case "google": + return provider.NewGoogleProvider(ctx, config.External.Google, scopes) + case "kakao": + return provider.NewKakaoProvider(config.External.Kakao, scopes) + case "keycloak": + return provider.NewKeycloakProvider(config.External.Keycloak, scopes) + case "linkedin": + return provider.NewLinkedinProvider(config.External.Linkedin, scopes) + case "linkedin_oidc": + return provider.NewLinkedinOIDCProvider(config.External.LinkedinOIDC, scopes) + case "notion": + return provider.NewNotionProvider(config.External.Notion) + case "spotify": + return provider.NewSpotifyProvider(config.External.Spotify, scopes) + case "slack": + return provider.NewSlackProvider(config.External.Slack, scopes) + case "slack_oidc": + return provider.NewSlackOIDCProvider(config.External.SlackOIDC, scopes) + case "twitch": + return provider.NewTwitchProvider(config.External.Twitch, scopes) + case "twitter": + return provider.NewTwitterProvider(config.External.Twitter, scopes) + case "vercel_marketplace": + return provider.NewVercelMarketplaceProvider(config.External.VercelMarketplace, scopes) + case "workos": + return provider.NewWorkOSProvider(config.External.WorkOS) + case "zoom": + return provider.NewZoomProvider(config.External.Zoom) + default: + return nil, fmt.Errorf("Provider %s could not be found", name) + } +} + +func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) { + ctx := r.Context() + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(ctx) + err := handler(w, r) + if err != nil { + q := getErrorQueryString(err, errorID, log, u.Query()) + u.RawQuery = q.Encode() + + // TODO: deprecate returning error details in the query fragment + hq := url.Values{} + if q.Get("error") != "" { + hq.Set("error", q.Get("error")) + } + if q.Get("error_description") != "" { + hq.Set("error_description", q.Get("error_description")) + } + if q.Get("error_code") != "" { + hq.Set("error_code", q.Get("error_code")) + } + u.Fragment = hq.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) + } +} + +func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q url.Values) *url.Values { + switch e := err.(type) { + case *HTTPError: + if e.ErrorCode == apierrors.ErrorCodeSignupDisabled { + q.Set("error", "access_denied") + } else if e.ErrorCode == apierrors.ErrorCodeUserBanned { + q.Set("error", "access_denied") + } else if e.ErrorCode == apierrors.ErrorCodeProviderEmailNeedsVerification { + q.Set("error", "access_denied") + } else if str, ok := oauthErrorMap[e.HTTPStatus]; ok { + q.Set("error", str) + } else { + q.Set("error", "server_error") + } + if e.HTTPStatus >= http.StatusInternalServerError { + e.ErrorID = errorID + // this will get us the stack trace too + log.WithError(e.Cause()).Error(e.Error()) + } else { + log.WithError(e.Cause()).Info(e.Error()) + } + q.Set("error_description", e.Message) + q.Set("error_code", e.ErrorCode) + case *OAuthError: + q.Set("error", e.Err) + q.Set("error_description", e.Description) + log.WithError(e.Cause()).Info(e.Error()) + case ErrorCause: + return getErrorQueryString(e.Cause(), errorID, log, q) + default: + error_type, error_description := "server_error", err.Error() + + // Provide better error messages for certain user-triggered Postgres errors. + if pgErr := utilities.NewPostgresError(e); pgErr != nil { + error_description = pgErr.Message + if oauthErrorType, ok := oauthErrorMap[pgErr.HttpStatusCode]; ok { + error_type = oauthErrorType + } + } + + q.Set("error", error_type) + q.Set("error_description", error_description) + } + return &q +} + +func (a *API) getExternalRedirectURL(r *http.Request) string { + ctx := r.Context() + config := a.config + if config.External.RedirectURL != "" { + return config.External.RedirectURL + } + if er := getExternalReferrer(ctx); er != "" { + return er + } + return config.SiteURL +} + +func (a *API) createNewIdentity(tx *storage.Connection, user *models.User, providerType string, identityData map[string]interface{}) (*models.Identity, error) { + identity, err := models.NewIdentity(user, providerType, identityData) + if err != nil { + return nil, err + } + + if terr := tx.Create(identity); terr != nil { + return nil, internalServerError("Error creating identity").WithInternalError(terr) + } + + return identity, nil +} diff --git a/api/external_apple_test.go b/internal/api/external_apple_test.go similarity index 82% rename from api/external_apple_test.go rename to internal/api/external_apple_test.go index 8b2e99a98..5a0b4970c 100644 --- a/api/external_apple_test.go +++ b/internal/api/external_apple_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) func (ts *ExternalTestSuite) TestSignupExternalApple() { @@ -17,12 +17,12 @@ func (ts *ExternalTestSuite) TestSignupExternalApple() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Apple.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Apple.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Apple.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("email name", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) diff --git a/internal/api/external_azure_test.go b/internal/api/external_azure_test.go new file mode 100644 index 000000000..aac124c78 --- /dev/null +++ b/internal/api/external_azure_test.go @@ -0,0 +1,269 @@ +package api + +import ( + "context" + "crypto" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/api/provider" +) + +const ( + azureUser string = `{"name":"Azure Test","email":"azure@example.com","sub":"azuretestid"}` + azureUserNoEmail string = `{"name":"Azure Test","sub":"azuretestid"}` +) + +func idTokenPrivateKey() *rsa.PrivateKey { + // #nosec + der, err := base64.StdEncoding.DecodeString("MIIEpAIBAAKCAQEAvklrFDsVgbhs3DOQICMqm4xdFoi/MHj/T6XH8S7wXWd0roqdWVarwCLV4y3DILkLre4PzNK+hEY5NAnoAKrsCMyyCb4Wdl8HCdJk4ojDqAig+DJw67imqZoxJMFJyIhfMJhwVK1V8GRUPATn855rygLo7wThahMJeEHNiJr3TtV6Rf35KSs7DuyoWIUSjISYabQozKqIvpdUpTpSqjlOQvjdAxggRyycBZSgLzjWhsA8metnAMO48bX4bgiHLR6Kzu/dfPyEVPfgeYpA2ebIY6GzIUxVS0yX8+ExA6jeLCkuepjLHuz5XCJtd6zzGDXr1eX7nA6ZIeUNdFbWRDnPawIDAQABAoIBABH4Qvl1HvHSJc2hvPGcAJER71SKc2uzcYDnCfu30BEyDO3Sv0tJiQyq/YHnt26mqviw66MPH9jD/PDyIou1mHa4RfPvlJV3IeYGjWprOfbrYbAuq0VHec24dv2el0YtwreHHcyRVfVOtDm6yODTzCAWqEKyNktbIuDNbgiBgetayaJecDRoFMF9TOCeMCL92iZytzAr7fi+JWtLkRS/GZRIBjbr8LJ/ueYoCRmIx3MIw0WdPp7v2ZfeRTxP7LxJZ+MAsrq2pstmZYP7K0305e0bCJX1HexfXLs2Ul7u8zaxrXL8zw4/9+/GMsAeU3ffCVnGz/RKL5+T6iuz2RotjFECgYEA+Xk7DGwRXfDg9xba1GVFGeiC4nybqZw/RfZKcz/RRJWSHRJV/ps1avtbca3B19rjI6rewZMO1NWNv/tI2BdXP8vAKUnI9OHJZ+J/eZzmqDE6qu0v0ddRFUDzCMWE0j8BjrUdy44n4NQgopcv14u0iyr9tuhGO6YXn2SuuvEkZokCgYEAw0PNnT55kpkEhXSp7An2hdBJEub9ST7hS6Kcd8let62/qUZ/t5jWigSkWC1A2bMtH55+LgudIFjiehwVzRs7jym2j4jkKZGonyAX1l9IWgXwKl7Pn49lEQH5Yk6MhnXdyLGoFTzXiUyk/fKvgXX7jow1bD3j6sAc8P495I7TyVMCgYAHg6VJrH+har37805IE3zPWPeIRuSRaUlmnBKGAigVfsPV6FV6w8YKIOQSOn+aNtecnWr0Pa+2rXAFllYNXDaej06Mb9KDvcFJRcM9MIKqEkGIIHjOQ0QH9drcKsbjZk5vs/jfxrpgxULuYstoHKclgff+aGSlK02O2YOB0f2csQKBgQCEC/MdNiWCpKXxFg7fB3HF1i/Eb56zjKlQu7uyKeQ6tG3bLEisQNg8Z5034Apt7gRC0KyluMbeHB2z1BBOLu9dBill8X3SOqVcTpiwKKlF76QVEx622YLQOJSMDXBscYK0+KchDY74U3N0JEzZcI7YPCrYcxYRJy+rLVNvn8LK7wKBgQDE8THsZ589e10F0zDBvPK56o8PJnPeH71sgdM2Co4oLzBJ6g0rpJOKfcc03fLHsoJVOAya9WZeIy6K8+WVdcPTadR07S4p8/tcK1eguu5qlmCUOzswrTKAaJoIHO7cddQp3nySIqgYtkGdHKuvlQDMQkEKJS0meOm+vdeAG2rkaA==") + if err != nil { + panic(err) + } + + privateKey, err := x509.ParsePKCS1PrivateKey(der) + if err != nil { + panic(err) + } + + privateKey.E = 65537 + + return privateKey +} + +func setupAzureOverrideVerifiers() { + provider.OverrideVerifiers["https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/oauth2/v2.0/authorize"] = func(ctx context.Context, config *oidc.Config) *oidc.IDTokenVerifier { + pk := idTokenPrivateKey() + + return oidc.NewVerifier( + provider.IssuerAzureMicrosoft, + &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{ + &pk.PublicKey, + }, + }, + config, + ) + } +} + +func mintIDToken(user string) string { + var idToken struct { + Issuer string `json:"iss"` + IssuedAt int `json:"iat"` + ExpiresAt int `json:"exp"` + Audience string `json:"aud"` + + Sub string `json:"sub,omitempty"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + XmsEdov any `json:"xms_edov,omitempty"` + } + + if err := json.Unmarshal([]byte(user), &idToken); err != nil { + panic(err) + } + + now := time.Now() + + idToken.Issuer = provider.IssuerAzureMicrosoft + idToken.IssuedAt = int(now.Unix()) + idToken.ExpiresAt = int(now.Unix() + 60*60) + idToken.Audience = "testclientid" + + header := base64.RawURLEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"RS256"}`)) + + data, err := json.Marshal(idToken) + if err != nil { + panic(err) + } + + payload := base64.RawURLEncoding.EncodeToString(data) + sum := sha256.Sum256([]byte(header + "." + payload)) + + pk := idTokenPrivateKey() + sig, err := rsa.SignPKCS1v15(nil, pk, crypto.SHA256, sum[:]) + if err != nil { + panic(err) + } + + token := header + "." + payload + "." + base64.RawURLEncoding.EncodeToString(sig) + + return token +} + +func (ts *ExternalTestSuite) TestSignupExternalAzure() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=azure", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Azure.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Azure.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("openid", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("azure", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func AzureTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2/v2.0/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Azure.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"azure_token","expires_in":100000,"id_token":%q}`, mintIDToken(user)) + default: + w.WriteHeader(500) + ts.Fail("unknown azure oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Azure.URL = server.URL + ts.Config.External.Azure.ApiURL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalAzure_AuthorizationCode() { + setupAzureOverrideVerifiers() + + ts.Config.DisableSignup = false + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, -1, "azure@example.com", "Azure Test", "azuretestid", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupErrorWhenNoUser() { + setupAzureOverrideVerifiers() + + ts.Config.DisableSignup = true + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "azure@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupErrorWhenNoEmail() { + setupAzureOverrideVerifiers() + + ts.Config.DisableSignup = true + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "azure@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupSuccessWithPrimaryEmail() { + setupAzureOverrideVerifiers() + + ts.Config.DisableSignup = true + + ts.createUser("azuretestid", "azure@example.com", "Azure Test", "", "") + + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, -1, "azure@example.com", "Azure Test", "azuretestid", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalAzureSuccessWhenMatchingToken() { + setupAzureOverrideVerifiers() + + // name should be populated from Azure API + ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") + + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, -1, "azure@example.com", "Azure Test", "azuretestid", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenNoMatchingToken() { + setupAzureOverrideVerifiers() + + tokenCount := 0 + code := "authcode" + azureUser := `{"name":"Azure Test","avatar":{"href":"http://example.com/avatar"}}` + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "azure", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenWrongToken() { + setupAzureOverrideVerifiers() + + ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") + + tokenCount := 0 + code := "authcode" + azureUser := `{"name":"Azure Test","avatar":{"href":"http://example.com/avatar"}}` + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "azure", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenEmailDoesntMatch() { + setupAzureOverrideVerifiers() + + ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") + + tokenCount := 0 + code := "authcode" + azureUser := `{"name":"Azure Test", "email":"other@example.com", "avatar":{"href":"http://example.com/avatar"}}` + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/api/external_bitbucket_test.go b/internal/api/external_bitbucket_test.go similarity index 97% rename from api/external_bitbucket_test.go rename to internal/api/external_bitbucket_test.go index ab48a1e1b..66b3bd4df 100644 --- a/api/external_bitbucket_test.go +++ b/internal/api/external_bitbucket_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) const ( @@ -22,12 +22,12 @@ func (ts *ExternalTestSuite) TestSignupExternalBitbucket() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Bitbucket.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Bitbucket.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Bitbucket.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("account email", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) diff --git a/api/external_discord_test.go b/internal/api/external_discord_test.go similarity index 96% rename from api/external_discord_test.go rename to internal/api/external_discord_test.go index 057cd9c59..7b6be8d34 100644 --- a/api/external_discord_test.go +++ b/internal/api/external_discord_test.go @@ -6,11 +6,11 @@ import ( "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) const ( - discordUser string = `{"id":"discordTestId","avatar":"abc","email":"discord@example.com","username":"Discord Test","verified":true}}` + discordUser string = `{"id":"discordTestId","avatar":"abc","email":"discord@example.com","username":"Discord Test","verified":true,"discriminator":"0001"}}` discordUserWrongEmail string = `{"id":"discordTestId","avatar":"abc","email":"other@example.com","username":"Discord Test","verified":true}}` discordUserNoEmail string = `{"id":"discordTestId","avatar":"abc","username":"Discord Test","verified":true}}` ) @@ -24,12 +24,12 @@ func (ts *ExternalTestSuite) TestSignupExternalDiscord() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Discord.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Discord.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Discord.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("email identify", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) diff --git a/api/external_facebook_test.go b/internal/api/external_facebook_test.go similarity index 97% rename from api/external_facebook_test.go rename to internal/api/external_facebook_test.go index 253715438..c1864bb9c 100644 --- a/api/external_facebook_test.go +++ b/internal/api/external_facebook_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) const ( @@ -24,12 +24,12 @@ func (ts *ExternalTestSuite) TestSignupExternalFacebook() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Facebook.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Facebook.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Facebook.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("email", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) diff --git a/internal/api/external_figma_test.go b/internal/api/external_figma_test.go new file mode 100644 index 000000000..3c91f1174 --- /dev/null +++ b/internal/api/external_figma_test.go @@ -0,0 +1,264 @@ +package api + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" +) + +func (ts *ExternalTestSuite) TestSignupExternalFigma() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=figma", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Figma.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Figma.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("files:read", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("figma", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func FigmaTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, email string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Figma.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"figma_token","expires_in":100000,"refresh_token":"figma_token"}`) + case "/v1/me": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprintf(w, `{"id":"figma-test-id","email":"%s","handle":"Figma Test","img_url":"http://example.com/avatar"}`, email) + default: + w.WriteHeader(500) + ts.Fail("unknown figma oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Figma.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalFigma_AuthorizationCode() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "figma@example.com", "Figma Test", "figma-test-id", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalFigma_PKCE() { + tokenCount, userCount := 0, 0 + code := "authcode" + + // for the plain challenge method, the code verifier == code challenge + // code challenge has to be between 43 - 128 chars for the plain challenge method + codeVerifier := "testtesttesttesttesttesttesttesttesttesttesttesttesttest" + + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + cases := []struct { + desc string + codeChallengeMethod string + }{ + { + desc: "SHA256", + codeChallengeMethod: "s256", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var codeChallenge string + if c.codeChallengeMethod == "s256" { + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + } else { + codeChallenge = codeVerifier + } + // Check for valid auth code returned + u := performPKCEAuthorization(ts, "figma", code, codeChallenge, c.codeChallengeMethod) + m, err := url.ParseQuery(u.RawQuery) + authCode := m["code"][0] + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), authCode) + + // Check for valid provider access token, mock does not return refresh token + user, err := models.FindUserByEmailAndAudience(ts.API.db, "figma@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user) + flowState, err := models.FindFlowStateByAuthCode(ts.API.db, authCode) + require.NoError(ts.T(), err) + require.Equal(ts.T(), "figma_token", flowState.ProviderAccessToken) + + // Exchange Auth Code for token + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": codeVerifier, + "auth_code": authCode, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Validate that access token and provider tokens are present + data := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.NotEmpty(ts.T(), data.Token) + require.NotEmpty(ts.T(), data.RefreshToken) + require.NotEmpty(ts.T(), data.ProviderAccessToken) + require.Equal(ts.T(), data.User.ID, user.ID) + }) + } +} + +func (ts *ExternalTestSuite) TestSignupExternalFigmaDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "figma@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalFigmaDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + email := "" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "figma@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalFigmaDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("figma-test-id", "figma@example.com", "Figma Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "figma@example.com", "Figma Test", "figma-test-id", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFigmaSuccessWhenMatchingToken() { + // name and avatar should be populated from Figma API + ts.createUser("figma-test-id", "figma@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "figma@example.com", "Figma Test", "figma-test-id", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFigmaErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + w := performAuthorizationRequest(ts, "figma", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFigmaErrorWhenWrongToken() { + ts.createUser("figma-test-id", "figma@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + w := performAuthorizationRequest(ts, "figma", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFigmaErrorWhenEmailDoesntMatch() { + ts.createUser("figma-test-id", "figma@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "other@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalFigmaErrorWhenUserBanned() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "figma@example.com", "Figma Test", "figma-test-id", "http://example.com/avatar") + + user, err := models.FindUserByEmailAndAudience(ts.API.db, "figma@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now().Add(24 * time.Hour) + user.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) + + u = performAuthorization(ts, "figma", code, "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") +} diff --git a/internal/api/external_fly_test.go b/internal/api/external_fly_test.go new file mode 100644 index 000000000..cf357c97b --- /dev/null +++ b/internal/api/external_fly_test.go @@ -0,0 +1,264 @@ +package api + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" +) + +func (ts *ExternalTestSuite) TestSignupExternalFly() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=fly", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Fly.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Fly.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("read", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("fly", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func FlyTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, email string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Fly.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"fly_token","expires_in":100000,"refresh_token":"fly_refresh_token"}`) + case "/oauth/token/info": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprintf(w, `{"resource_owner_id":"test_resource_owner_id","scope":["read"],"expires_in":1111,"application":{"uid":"test_app_uid"},"created_at":1696003692,"user_id":"test_user_id","user_name":"test_user","email":"%s","organizations":[{"id":"test_org_id","role":"test"}]}`, email) + default: + w.WriteHeader(500) + ts.Fail("unknown fly oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Fly.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalFly_AuthorizationCode() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "fly@example.com", "test_user", "test_user_id", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalFly_PKCE() { + tokenCount, userCount := 0, 0 + code := "authcode" + + // for the plain challenge method, the code verifier == code challenge + // code challenge has to be between 43 - 128 chars for the plain challenge method + codeVerifier := "testtesttesttesttesttesttesttesttesttesttesttesttesttest" + + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + cases := []struct { + desc string + codeChallengeMethod string + }{ + { + desc: "SHA256", + codeChallengeMethod: "s256", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var codeChallenge string + if c.codeChallengeMethod == "s256" { + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + } else { + codeChallenge = codeVerifier + } + // Check for valid auth code returned + u := performPKCEAuthorization(ts, "fly", code, codeChallenge, c.codeChallengeMethod) + m, err := url.ParseQuery(u.RawQuery) + authCode := m["code"][0] + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), authCode) + + // Check for valid provider access token, mock does not return refresh token + user, err := models.FindUserByEmailAndAudience(ts.API.db, "fly@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user) + flowState, err := models.FindFlowStateByAuthCode(ts.API.db, authCode) + require.NoError(ts.T(), err) + require.Equal(ts.T(), "fly_token", flowState.ProviderAccessToken) + + // Exchange Auth Code for token + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": codeVerifier, + "auth_code": authCode, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Validate that access token and provider tokens are present + data := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.NotEmpty(ts.T(), data.Token) + require.NotEmpty(ts.T(), data.RefreshToken) + require.NotEmpty(ts.T(), data.ProviderAccessToken) + require.Equal(ts.T(), data.User.ID, user.ID) + }) + } +} + +func (ts *ExternalTestSuite) TestSignupExternalFlyDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", email) +} + +func (ts *ExternalTestSuite) TestSignupExternalFlyDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + email := "" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "fly@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalFlyDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("test_user_id", "fly@example.com", "test_user", "", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "fly@example.com", "test_user", "test_user_id", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFlySuccessWhenMatchingToken() { + // name and avatar should be populated from fly API + ts.createUser("test_user_id", "fly@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "fly@example.com", "test_user", "test_user_id", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFlyErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + w := performAuthorizationRequest(ts, "fly", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFlyErrorWhenWrongToken() { + ts.createUser("test_user_id", "fly@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + w := performAuthorizationRequest(ts, "fly", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFlyErrorWhenEmailDoesntMatch() { + ts.createUser("test_user_id", "fly@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "other@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalFlyErrorWhenUserBanned() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "fly@example.com", "test_user", "test_user_id", "") + + user, err := models.FindUserByEmailAndAudience(ts.API.db, "fly@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now().Add(24 * time.Hour) + user.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) + + u = performAuthorization(ts, "fly", code, "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") +} diff --git a/api/external_github_test.go b/internal/api/external_github_test.go similarity index 71% rename from api/external_github_test.go rename to internal/api/external_github_test.go index 9ebcda49b..7b9d31e89 100644 --- a/api/external_github_test.go +++ b/internal/api/external_github_test.go @@ -1,15 +1,19 @@ package api import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "time" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/models" + jwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" ) func (ts *ExternalTestSuite) TestSignupExternalGithub() { @@ -21,12 +25,12 @@ func (ts *ExternalTestSuite) TestSignupExternalGithub() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Github.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Github.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Github.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("user:email", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) @@ -77,6 +81,80 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHub_AuthorizationCode() { assertAuthorizationSuccess(ts, u, tokenCount, userCount, "github@example.com", "GitHub Test", "123", "http://example.com/avatar") } +func (ts *ExternalTestSuite) TestSignupExternalGitHub_PKCE() { + tokenCount, userCount := 0, 0 + code := "authcode" + + // for the plain challenge method, the code verifier == code challenge + // code challenge has to be between 43 - 128 chars for the plain challenge method + codeVerifier := "testtesttesttesttesttesttesttesttesttesttesttesttesttest" + + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + cases := []struct { + desc string + codeChallengeMethod string + }{ + { + desc: "SHA256", + codeChallengeMethod: "s256", + }, + { + desc: "Plain", + codeChallengeMethod: "plain", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var codeChallenge string + if c.codeChallengeMethod == "s256" { + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + } else { + codeChallenge = codeVerifier + } + // Check for valid auth code returned + u := performPKCEAuthorization(ts, "github", code, codeChallenge, c.codeChallengeMethod) + m, err := url.ParseQuery(u.RawQuery) + authCode := m["code"][0] + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), authCode) + + // Check for valid provider access token, mock does not return refresh token + user, err := models.FindUserByEmailAndAudience(ts.API.db, "github@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user) + flowState, err := models.FindFlowStateByAuthCode(ts.API.db, authCode) + require.NoError(ts.T(), err) + require.Equal(ts.T(), "github_token", flowState.ProviderAccessToken) + + // Exchange Auth Code for token + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": codeVerifier, + "auth_code": authCode, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Validate that access token and provider tokens are present + data := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.NotEmpty(ts.T(), data.Token) + require.NotEmpty(ts.T(), data.RefreshToken) + require.NotEmpty(ts.T(), data.ProviderAccessToken) + require.Equal(ts.T(), data.User.ID, user.ID) + }) + } + +} + func (ts *ExternalTestSuite) TestSignupExternalGitHubDisableSignupErrorWhenNoUser() { ts.Config.DisableSignup = true tokenCount, userCount := 0, 0 @@ -189,6 +267,7 @@ func (ts *ExternalTestSuite) TestInviteTokenExternalGitHubErrorWhenEmailDoesntMa } func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenVerifiedFalse() { + ts.Config.Mailer.AllowUnverifiedEmailSignIns = false tokenCount, userCount := 0, 0 code := "authcode" emails := `[{"email":"github@example.com", "primary": true, "verified": false}]` @@ -197,12 +276,7 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenVerifiedFalse() { u := performAuthorization(ts, "github", code, "") - v, err := url.ParseQuery(u.Fragment) - ts.Require().NoError(err) - ts.Equal("unauthorized_client", v.Get("error")) - ts.Equal("401", v.Get("error_code")) - ts.Equal("Unverified email with github", v.Get("error_description")) - assertAuthorizationFailure(ts, u, "", "", "") + assertAuthorizationFailure(ts, u, "Unverified email with github. A confirmation email has been sent to your github email", "access_denied", "") } func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { @@ -215,12 +289,12 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { u := performAuthorization(ts, "github", code, "") assertAuthorizationSuccess(ts, u, tokenCount, userCount, "github@example.com", "GitHub Test", "123", "http://example.com/avatar") - user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "github@example.com", ts.Config.JWT.Aud) + user, err := models.FindUserByEmailAndAudience(ts.API.db, "github@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) t := time.Now().Add(24 * time.Hour) user.BannedUntil = &t require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "github", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/api/external_gitlab_test.go b/internal/api/external_gitlab_test.go similarity index 97% rename from api/external_gitlab_test.go rename to internal/api/external_gitlab_test.go index 8a8b0fbf0..5a14a0a2b 100644 --- a/api/external_gitlab_test.go +++ b/internal/api/external_gitlab_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) const ( @@ -24,12 +24,12 @@ func (ts *ExternalTestSuite) TestSignupExternalGitlab() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Gitlab.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Gitlab.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Gitlab.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("read_user", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) diff --git a/api/external_google_test.go b/internal/api/external_google_test.go similarity index 87% rename from api/external_google_test.go rename to internal/api/external_google_test.go index 8992d0a65..7b3b6d156 100644 --- a/api/external_google_test.go +++ b/internal/api/external_google_test.go @@ -1,21 +1,26 @@ package api import ( + "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/provider" ) const ( googleUser string = `{"id":"googleTestId","name":"Google Test","picture":"http://example.com/avatar","email":"google@example.com","verified_email":true}}` googleUserWrongEmail string = `{"id":"googleTestId","name":"Google Test","picture":"http://example.com/avatar","email":"other@example.com","verified_email":true}}` - googleUserNoEmail string = `{"id":"googleTestId","name":"Google Test","picture":"http://example.com/avatar","verified_email":true}}` + googleUserNoEmail string = `{"id":"googleTestId","name":"Google Test","picture":"http://example.com/avatar","verified_email":false}}` ) func (ts *ExternalTestSuite) TestSignupExternalGoogle() { + provider.ResetGoogleProvider() + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=google", nil) w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) @@ -24,12 +29,12 @@ func (ts *ExternalTestSuite) TestSignupExternalGoogle() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Google.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Google.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Google.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("email profile", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) @@ -40,8 +45,17 @@ func (ts *ExternalTestSuite) TestSignupExternalGoogle() { } func GoogleTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + provider.ResetGoogleProvider() + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { + case "/.well-known/openid-configuration": + w.Header().Add("Content-Type", "application/json") + require.NoError(ts.T(), json.NewEncoder(w).Encode(map[string]any{ + "issuer": server.URL, + "token_endpoint": server.URL + "/o/oauth2/token", + })) case "/o/oauth2/token": *tokenCount++ ts.Equal(code, r.FormValue("code")) @@ -60,7 +74,7 @@ func GoogleTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *in } })) - ts.Config.External.Google.URL = server.URL + provider.OverrideGoogleProvider(server.URL, server.URL+"/userinfo/v2/me") return server } diff --git a/internal/api/external_kakao_test.go b/internal/api/external_kakao_test.go new file mode 100644 index 000000000..729f723a7 --- /dev/null +++ b/internal/api/external_kakao_test.go @@ -0,0 +1,238 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" +) + +func (ts *ExternalTestSuite) TestSignupExternalKakao() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=kakao", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Kakao.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Kakao.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("kakao", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func KakaoTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, emails string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Kakao.RedirectURI, r.FormValue("redirect_uri")) + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"kakao_token","expires_in":100000}`) + case "/v2/user/me": + *userCount++ + var emailList []provider.Email + if err := json.Unmarshal([]byte(emails), &emailList); err != nil { + ts.Fail("Invalid email json %s", emails) + } + + var email *provider.Email + + for i, e := range emailList { + if len(e.Email) > 0 { + email = &emailList[i] + break + } + } + + w.Header().Add("Content-Type", "application/json") + if email != nil { + fmt.Fprintf(w, ` + { + "id":123, + "kakao_account": { + "profile": { + "nickname":"Kakao Test", + "profile_image_url":"http://example.com/avatar" + }, + "email": "%v", + "is_email_valid": %v, + "is_email_verified": %v + } + }`, email.Email, email.Verified, email.Verified) + } else { + fmt.Fprint(w, ` + { + "id":123, + "kakao_account": { + "profile": { + "nickname":"Kakao Test", + "profile_image_url":"http://example.com/avatar" + } + } + }`) + } + default: + w.WriteHeader(500) + ts.Fail("unknown kakao oauth call %s", r.URL.Path) + } + })) + ts.Config.External.Kakao.URL = server.URL + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalKakao_AuthorizationCode() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + u := performAuthorization(ts, "kakao", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "kakao@example.com", "Kakao Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "kakao@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "kakao@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("123", "kakao@example.com", "Kakao Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "kakao@example.com", "Kakao Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKakaoSuccessWhenMatchingToken() { + // name and avatar should be populated from Kakao API + ts.createUser("123", "kakao@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "kakao@example.com", "Kakao Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKakaoErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "kakao", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKakaoErrorWhenWrongToken() { + ts.createUser("123", "kakao@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "kakao", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKakaoErrorWhenEmailDoesntMatch() { + ts.createUser("123", "kakao@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"other@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenVerifiedFalse() { + ts.Config.Mailer.AllowUnverifiedEmailSignIns = false + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": false}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + + assertAuthorizationFailure(ts, u, "Unverified email with kakao. A confirmation email has been sent to your kakao email", "access_denied", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "kakao@example.com", "Kakao Test", "123", "http://example.com/avatar") + + user, err := models.FindUserByEmailAndAudience(ts.API.db, "kakao@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now().Add(24 * time.Hour) + user.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) + + u = performAuthorization(ts, "kakao", code, "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") +} diff --git a/internal/api/external_keycloak_test.go b/internal/api/external_keycloak_test.go new file mode 100644 index 000000000..a0952eaac --- /dev/null +++ b/internal/api/external_keycloak_test.go @@ -0,0 +1,182 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + keycloakUser string = `{"sub": "keycloaktestid", "name": "Keycloak Test", "email": "keycloak@example.com", "preferred_username": "keycloak", "email_verified": true}` + keycloakUserNoEmail string = `{"sub": "keycloaktestid", "name": "Keycloak Test", "preferred_username": "keycloak", "email_verified": false}` +) + +func (ts *ExternalTestSuite) TestSignupExternalKeycloak() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=keycloak", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Keycloak.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Keycloak.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("profile email", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("keycloak", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func KeycloakTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/protocol/openid-connect/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Keycloak.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"keycloak_token","expires_in":100000}`) + case "/protocol/openid-connect/userinfo": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown keycloak oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Keycloak.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloakWithoutURLSetup() { + ts.createUser("keycloaktestid", "keycloak@example.com", "Keycloak Test", "", "") + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + ts.Config.External.Keycloak.URL = "" + defer server.Close() + + w := performAuthorizationRequest(ts, "keycloak", code) + ts.Equal(w.Code, http.StatusBadRequest) +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloak_AuthorizationCode() { + ts.Config.DisableSignup = false + ts.createUser("keycloaktestid", "keycloak@example.com", "Keycloak Test", "", "") + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "keycloak@example.com", "Keycloak Test", "keycloaktestid", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloakDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "keycloak@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloakDisableSignupErrorWhenNoEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "keycloak@example.com") + +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloakDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("keycloaktestid", "keycloak@example.com", "Keycloak Test", "", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "keycloak@example.com", "Keycloak Test", "keycloaktestid", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKeycloakSuccessWhenMatchingToken() { + // name and avatar should be populated from Keycloak API + ts.createUser("keycloaktestid", "keycloak@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "keycloak@example.com", "Keycloak Test", "keycloaktestid", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKeycloakErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + keycloakUser := `{"name":"Keycloak Test","avatar":{"href":"http://example.com/avatar"}}` + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "keycloak", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKeycloakErrorWhenWrongToken() { + ts.createUser("keycloaktestid", "keycloak@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + keycloakUser := `{"name":"Keycloak Test","avatar":{"href":"http://example.com/avatar"}}` + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "keycloak", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKeycloakErrorWhenEmailDoesntMatch() { + ts.createUser("keycloaktestid", "keycloak@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + keycloakUser := `{"name":"Keycloak Test", "email":"other@example.com", "avatar":{"href":"http://example.com/avatar"}}` + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/api/external_linkedin_test.go b/internal/api/external_linkedin_test.go similarity index 78% rename from api/external_linkedin_test.go rename to internal/api/external_linkedin_test.go index cb5392482..fe49932e5 100644 --- a/api/external_linkedin_test.go +++ b/internal/api/external_linkedin_test.go @@ -6,13 +6,14 @@ import ( "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) const ( - linkedinUser string = `{"id":"linkedinTestId","firstName":{"localized":{"en_US":"Linkedin"},"preferredLocale":{"country":"US","language":"en"}},"lastName":{"localized":{"en_US":"Test"},"preferredLocale":{"country":"US","language":"en"}},"profilePicture":{"displayImage~":{"elements":[{"identifiers":[{"identifier":"http://example.com/avatar"}]}]}}}` - linkedinEmail string = `{"elements": [{"handle": "","handle~": {"emailAddress": "linkedin@example.com"}}]}` - linkedinWrongEmail string = `{"elements": [{"handle": "","handle~": {"emailAddress": "other@example.com"}}]}` + linkedinUser string = `{"id":"linkedinTestId","firstName":{"localized":{"en_US":"Linkedin"},"preferredLocale":{"country":"US","language":"en"}},"lastName":{"localized":{"en_US":"Test"},"preferredLocale":{"country":"US","language":"en"}},"profilePicture":{"displayImage~":{"elements":[{"identifiers":[{"identifier":"http://example.com/avatar"}]}]}}}` + linkedinUserNoProfilePic string = `{"id":"linkedinTestId","firstName":{"localized":{"en_US":"Linkedin"},"preferredLocale":{"country":"US","language":"en"}},"lastName":{"localized":{"en_US":"Test"},"preferredLocale":{"country":"US","language":"en"}},"profilePicture":{"displayImage~":{"elements":[]}}}` + linkedinEmail string = `{"elements": [{"handle": "","handle~": {"emailAddress": "linkedin@example.com"}}]}` + linkedinWrongEmail string = `{"elements": [{"handle": "","handle~": {"emailAddress": "other@example.com"}}]}` ) func (ts *ExternalTestSuite) TestSignupExternalLinkedin() { @@ -24,12 +25,12 @@ func (ts *ExternalTestSuite) TestSignupExternalLinkedin() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Linkedin.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Linkedin.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Linkedin.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("r_emailaddress r_liteprofile", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) @@ -156,3 +157,14 @@ func (ts *ExternalTestSuite) TestInviteTokenExternalLinkedinErrorWhenEmailDoesnt assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") } + +func (ts *ExternalTestSuite) TestSignupExternalLinkedin_MissingProfilePic() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUserNoProfilePic, linkedinEmail) + defer server.Close() + + u := performAuthorization(ts, "linkedin", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "linkedin@example.com", "Linkedin Test", "linkedinTestId", "") +} diff --git a/api/external_notion_test.go b/internal/api/external_notion_test.go similarity index 96% rename from api/external_notion_test.go rename to internal/api/external_notion_test.go index be82e09ed..268e4492b 100644 --- a/api/external_notion_test.go +++ b/internal/api/external_notion_test.go @@ -6,13 +6,13 @@ import ( "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) const ( notionUser string = `{"bot":{"owner":{"user":{"id":"notionTestId","name":"Notion Test","avatar_url":"http://example.com/avatar","person":{"email":"notion@example.com"},"verified_email":true}}}}` notionUserWrongEmail string = `{"bot":{"owner":{"user":{"id":"notionTestId","name":"Notion Test","avatar_url":"http://example.com/avatar","person":{"email":"other@example.com"},"verified_email":true}}}}` - notionUserNoEmail string = `{"bot":{"owner":{"user":{"id":"notionTestId","name":"Notion Test","avatar_url":"http://example.com/avatar","verified_email":true}}}}}` + notionUserNoEmail string = `{"bot":{"owner":{"user":{"id":"notionTestId","name":"Notion Test","avatar_url":"http://example.com/avatar","verified_email":true}}}}` ) func (ts *ExternalTestSuite) TestSignupExternalNotion() { @@ -24,11 +24,11 @@ func (ts *ExternalTestSuite) TestSignupExternalNotion() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Notion.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Notion.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Notion.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) diff --git a/api/external_oauth.go b/internal/api/external_oauth.go similarity index 59% rename from api/external_oauth.go rename to internal/api/external_oauth.go index 3a4899efc..3b8f2f904 100644 --- a/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -2,35 +2,29 @@ package api import ( "context" + "fmt" "net/http" "net/url" "github.com/mrjones/oauth" - "github.com/netlify/gotrue/api/provider" - "github.com/netlify/gotrue/storage" "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/utilities" ) // OAuthProviderData contains the userData and token returned by the oauth provider type OAuthProviderData struct { - userData *provider.UserProvidedData - token string + userData *provider.UserProvidedData + token string + refreshToken string + code string } -// loadOAuthState parses the `state` query parameter as a JWS payload, +// loadFlowState parses the `state` query parameter as a JWS payload, // extracting the provider requested -func (a *API) loadOAuthState(w http.ResponseWriter, r *http.Request) (context.Context, error) { - var state string - if r.Method == http.MethodPost { - state = r.FormValue("state") - } else { - state = r.URL.Query().Get("state") - } - - if state == "" { - return nil, badRequestError("OAuth state parameter missing") - } - +func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) { ctx := r.Context() oauthToken := r.URL.Query().Get("oauth_token") if oauthToken != "" { @@ -40,7 +34,21 @@ func (a *API) loadOAuthState(w http.ResponseWriter, r *http.Request) (context.Co if oauthVerifier != "" { ctx = withOAuthVerifier(ctx, oauthVerifier) } - return a.loadExternalState(ctx, state) + + var err error + ctx, err = a.loadExternalState(ctx, r) + if err != nil { + u, uerr := url.ParseRequestURI(a.config.SiteURL) + if uerr != nil { + return ctx, internalServerError("site url is improperly formatted").WithInternalError(uerr) + } + + q := getErrorQueryString(err, utilities.GetRequestID(ctx), observability.GetLogEntry(r).Entry, u.Query()) + u.RawQuery = q.Encode() + + http.Redirect(w, r, u.String(), http.StatusSeeOther) + } + return ctx, err } func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) { @@ -58,15 +66,15 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s oauthCode := rq.Get("code") if oauthCode == "" { - return nil, badRequestError("Authorization code missing") + return nil, badRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing") } oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } - log := getLogEntry(r) + log := observability.GetLogEntry(r).Entry log.WithFields(logrus.Fields{ "provider": providerType, "code": oauthCode, @@ -79,7 +87,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s userData, err := oAuthProvider.GetUserData(ctx, token) if err != nil { - return nil, internalServerError("Error getting user email from external provider").WithInternalError(err) + return nil, internalServerError("Error getting user profile from external provider").WithInternalError(err) } switch externalProvider := oAuthProvider.(type) { @@ -95,34 +103,26 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s } return &OAuthProviderData{ - userData: userData, - token: token.AccessToken, + userData: userData, + token: token.AccessToken, + refreshToken: token.RefreshToken, + code: oauthCode, }, nil } -func (a *API) oAuth1Callback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) { +func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthProviderData, error) { oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) - } - value, err := storage.GetFromSession(providerType, r) - if err != nil { - return &OAuthProviderData{}, err + return nil, badRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } oauthToken := getRequestToken(ctx) oauthVerifier := getOAuthVerifier(ctx) var accessToken *oauth.AccessToken var userData *provider.UserProvidedData if twitterProvider, ok := oAuthProvider.(*provider.TwitterProvider); ok { - requestToken, err := twitterProvider.Unmarshal(value) - if err != nil { - return &OAuthProviderData{}, err - } - if requestToken.Token != oauthToken { - return nil, internalServerError("Request token doesn't match token in callback") - } - twitterProvider.OauthVerifier = oauthVerifier - accessToken, err = twitterProvider.Consumer.AuthorizeToken(requestToken, oauthVerifier) + accessToken, err = twitterProvider.Consumer.AuthorizeToken(&oauth.RequestToken{ + Token: oauthToken, + }, oauthVerifier) if err != nil { return nil, internalServerError("Unable to retrieve access token").WithInternalError(err) } @@ -133,8 +133,9 @@ func (a *API) oAuth1Callback(ctx context.Context, r *http.Request, providerType } return &OAuthProviderData{ - userData: userData, - token: accessToken.Token, + userData: userData, + token: accessToken.Token, + refreshToken: "", }, nil } @@ -150,6 +151,6 @@ func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthPro case provider.OAuthProvider: return p, nil default: - return nil, badRequestError("Provider can not be used for OAuth") + return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name) } } diff --git a/internal/api/external_slack_oidc_test.go b/internal/api/external_slack_oidc_test.go new file mode 100644 index 000000000..acd2e784d --- /dev/null +++ b/internal/api/external_slack_oidc_test.go @@ -0,0 +1,33 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +func (ts *ExternalTestSuite) TestSignupExternalSlackOIDC() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=slack_oidc", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Slack.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Slack.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("profile email openid", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("slack_oidc", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} diff --git a/internal/api/external_test.go b/internal/api/external_test.go new file mode 100644 index 000000000..bef89d736 --- /dev/null +++ b/internal/api/external_test.go @@ -0,0 +1,254 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type ExternalTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestExternal(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &ExternalTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *ExternalTestSuite) SetupTest() { + ts.Config.DisableSignup = false + ts.Config.Mailer.Autoconfirm = false + + models.TruncateAll(ts.API.db) +} + +func (ts *ExternalTestSuite) createUser(providerId string, email string, name string, avatar string, confirmationToken string) (*models.User, error) { + // Cleanup existing user, if they already exist + if u, _ := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud); u != nil { + require.NoError(ts.T(), ts.API.db.Destroy(u), "Error deleting user") + } + + userData := map[string]interface{}{"provider_id": providerId, "full_name": name} + if avatar != "" { + userData["avatar_url"] = avatar + } + u, err := models.NewUser("", email, "test", ts.Config.JWT.Aud, userData) + + if confirmationToken != "" { + u.ConfirmationToken = confirmationToken + } + ts.Require().NoError(err, "Error making new user") + ts.Require().NoError(ts.API.db.Create(u), "Error creating user") + + if confirmationToken != "" { + ts.Require().NoError(models.CreateOneTimeToken(ts.API.db, u.ID, email, u.ConfirmationToken, models.ConfirmationToken), "Error creating one-time confirmation/invite token") + } + + i, err := models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": email, + }) + ts.Require().NoError(err) + ts.Require().NoError(ts.API.db.Create(i), "Error creating identity") + + return u, err +} + +func performAuthorizationRequest(ts *ExternalTestSuite, provider string, inviteToken string) *httptest.ResponseRecorder { + authorizeURL := "http://localhost/authorize?provider=" + provider + if inviteToken != "" { + authorizeURL = authorizeURL + "&invite_token=" + inviteToken + } + + req := httptest.NewRequest(http.MethodGet, authorizeURL, nil) + req.Header.Set("Referer", "https://example.netlify.com/admin") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + return w +} + +func performPKCEAuthorizationRequest(ts *ExternalTestSuite, provider, codeChallenge, codeChallengeMethod string) *httptest.ResponseRecorder { + authorizeURL := "http://localhost/authorize?provider=" + provider + if codeChallenge != "" { + authorizeURL = authorizeURL + "&code_challenge=" + codeChallenge + "&code_challenge_method=" + codeChallengeMethod + } + + req := httptest.NewRequest(http.MethodGet, authorizeURL, nil) + req.Header.Set("Referer", "https://example.supabase.com/admin") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + return w +} + +func performPKCEAuthorization(ts *ExternalTestSuite, provider, code, codeChallenge, codeChallengeMethod string) *url.URL { + w := performPKCEAuthorizationRequest(ts, provider, codeChallenge, codeChallengeMethod) + ts.Require().Equal(http.StatusFound, w.Code) + // Get code and state from the redirect + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + state := q.Get("state") + testURL, err := url.Parse("http://localhost/callback") + ts.Require().NoError(err) + v := testURL.Query() + v.Set("code", code) + v.Set("state", state) + testURL.RawQuery = v.Encode() + // Use the code to get a token + req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + + return u + +} + +func performAuthorization(ts *ExternalTestSuite, provider string, code string, inviteToken string) *url.URL { + w := performAuthorizationRequest(ts, provider, inviteToken) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + state := q.Get("state") + + // auth server callback + testURL, err := url.Parse("http://localhost/callback") + ts.Require().NoError(err) + v := testURL.Query() + v.Set("code", code) + v.Set("state", state) + testURL.RawQuery = v.Encode() + req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + ts.Require().Equal("/admin", u.Path) + + return u +} + +func assertAuthorizationSuccess(ts *ExternalTestSuite, u *url.URL, tokenCount int, userCount int, email string, name string, providerId string, avatar string) { + // ensure redirect has #access_token=... + v, err := url.ParseQuery(u.RawQuery) + ts.Require().NoError(err) + ts.Require().Empty(v.Get("error_description")) + ts.Require().Empty(v.Get("error")) + + v, err = url.ParseQuery(u.Fragment) + ts.Require().NoError(err) + ts.NotEmpty(v.Get("access_token")) + ts.NotEmpty(v.Get("refresh_token")) + ts.NotEmpty(v.Get("expires_in")) + ts.Equal("bearer", v.Get("token_type")) + + ts.Equal(1, tokenCount) + if userCount > -1 { + ts.Equal(1, userCount) + } + + // ensure user has been created with metadata + user, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) + ts.Require().NoError(err) + ts.Equal(providerId, user.UserMetaData["provider_id"]) + ts.Equal(name, user.UserMetaData["full_name"]) + if avatar == "" { + ts.Equal(nil, user.UserMetaData["avatar_url"]) + } else { + ts.Equal(avatar, user.UserMetaData["avatar_url"]) + } +} + +func assertAuthorizationFailure(ts *ExternalTestSuite, u *url.URL, errorDescription string, errorType string, email string) { + // ensure new sign ups error + v, err := url.ParseQuery(u.RawQuery) + ts.Require().NoError(err) + ts.Require().Equal(errorDescription, v.Get("error_description")) + ts.Require().Equal(errorType, v.Get("error")) + + v, err = url.ParseQuery(u.Fragment) + ts.Require().NoError(err) + ts.Empty(v.Get("access_token")) + ts.Empty(v.Get("refresh_token")) + ts.Empty(v.Get("expires_in")) + ts.Empty(v.Get("token_type")) + + // ensure user is nil + user, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) + ts.Require().Error(err, "User not found") + ts.Require().Nil(user) +} + +// TestSignupExternalUnsupported tests API /authorize for an unsupported external provider +func (ts *ExternalTestSuite) TestSignupExternalUnsupported() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=external", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Equal(w.Code, http.StatusBadRequest) +} + +func (ts *ExternalTestSuite) TestRedirectErrorsShouldPreserveParams() { + // Request with invalid external provider + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=external", nil) + w := httptest.NewRecorder() + cases := []struct { + Desc string + RedirectURL string + QueryParams []string + ErrorMessage string + }{ + { + Desc: "Should preserve redirect query params on error", + RedirectURL: "http://example.com/path?paramforpreservation=value2", + QueryParams: []string{"paramforpreservation"}, + ErrorMessage: "invalid_request", + }, + { + Desc: "Error param should be overwritten", + RedirectURL: "http://example.com/path?error=abc", + QueryParams: []string{"error"}, + ErrorMessage: "invalid_request", + }, + } + for _, c := range cases { + parsedURL, err := url.Parse(c.RedirectURL) + require.Equal(ts.T(), err, nil) + + redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL) + + parsedParams, err := url.ParseQuery(parsedURL.RawQuery) + require.Equal(ts.T(), err, nil) + + // An error and description should be returned + expectedQueryParams := append(c.QueryParams, "error", "error_description") + + for _, expectedQueryParam := range expectedQueryParams { + val, exists := parsedParams[expectedQueryParam] + require.True(ts.T(), exists) + if expectedQueryParam == "error" { + require.Equal(ts.T(), val[0], c.ErrorMessage) + } + } + } +} diff --git a/api/external_twitch_test.go b/internal/api/external_twitch_test.go similarity index 94% rename from api/external_twitch_test.go rename to internal/api/external_twitch_test.go index a4e473cae..694a5ff68 100644 --- a/api/external_twitch_test.go +++ b/internal/api/external_twitch_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "net/url" - jwt "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) const ( @@ -23,12 +23,12 @@ func (ts *ExternalTestSuite) TestSignupExternalTwitch() { ts.Require().NoError(err, "redirect url parse failed") q := u.Query() ts.Equal(ts.Config.External.Twitch.RedirectURI, q.Get("redirect_uri")) - ts.Equal(ts.Config.External.Twitch.ClientID, q.Get("client_id")) + ts.Equal(ts.Config.External.Twitch.ClientID, []string{q.Get("client_id")}) ts.Equal("code", q.Get("response_type")) ts.Equal("user:read:email", q.Get("scope")) claims := ExternalProviderClaims{} - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) @@ -107,7 +107,7 @@ func (ts *ExternalTestSuite) TestSignupExternalTwitchDisableSignupErrorWhenEmpty func (ts *ExternalTestSuite) TestSignupExternalTwitchDisableSignupSuccessWithPrimaryEmail() { ts.Config.DisableSignup = true - ts.createUser("twitchTestId", "twitch@example.com", "Twitch Test", "https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8", "") + ts.createUser("twitchTestId", "twitch@example.com", "Twitch user", "https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8", "") tokenCount, userCount := 0, 0 code := "authcode" @@ -116,7 +116,7 @@ func (ts *ExternalTestSuite) TestSignupExternalTwitchDisableSignupSuccessWithPri u := performAuthorization(ts, "twitch", code, "") - assertAuthorizationSuccess(ts, u, tokenCount, userCount, "twitch@example.com", "Twitch Test", "twitchTestId", "https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "twitch@example.com", "Twitch user", "twitchTestId", "https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8") } func (ts *ExternalTestSuite) TestInviteTokenExternalTwitchSuccessWhenMatchingToken() { diff --git a/api/external_twitter_test.go b/internal/api/external_twitter_test.go similarity index 100% rename from api/external_twitter_test.go rename to internal/api/external_twitter_test.go diff --git a/internal/api/external_workos_test.go b/internal/api/external_workos_test.go new file mode 100644 index 000000000..eedd5b00e --- /dev/null +++ b/internal/api/external_workos_test.go @@ -0,0 +1,221 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + workosUser string = `{"id":"test_prof_workos","first_name":"John","last_name":"Doe","email":"workos@example.com","connection_id":"test_conn_1","organization_id":"test_org_1","connection_type":"test","idp_id":"test_idp_1","object": "profile","raw_attributes": {}}` + workosUserWrongEmail string = `{"id":"test_prof_workos","first_name":"John","last_name":"Doe","email":"other@example.com","connection_id":"test_conn_1","organization_id":"test_org_1","connection_type":"test","idp_id":"test_idp_1","object": "profile","raw_attributes": {}}` + workosUserNoEmail string = `{"id":"test_prof_workos","first_name":"John","last_name":"Doe","connection_id":"test_conn_1","organization_id":"test_org_1","connection_type":"test","idp_id":"test_idp_1","object": "profile","raw_attributes": {}}` +) + +func (ts *ExternalTestSuite) TestSignupExternalWorkOSWithConnection() { + connection := "test_connection_id" + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost/authorize?provider=workos&connection=%s", connection), nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.WorkOS.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.WorkOS.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("", q.Get("scope")) + ts.Equal(connection, q.Get("connection")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("workos", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkOSWithOrganization() { + organization := "test_organization_id" + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost/authorize?provider=workos&organization=%s", organization), nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.WorkOS.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.WorkOS.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("", q.Get("scope")) + ts.Equal(organization, q.Get("organization")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("workos", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkOSWithProvider() { + provider := "test_provider" + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost/authorize?provider=workos&workos_provider=%s", provider), nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.WorkOS.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.WorkOS.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("", q.Get("scope")) + ts.Equal(provider, q.Get("provider")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("workos", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func WorkosTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/sso/token": + // WorkOS returns the user data along with the token. + *tokenCount++ + *userCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.WorkOS.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"workos_token","expires_in":100000,"profile":%s}`, user) + default: + fmt.Printf("%s", r.URL.Path) + w.WriteHeader(500) + ts.Fail("unknown workos oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.WorkOS.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkosAuthorizationCode() { + ts.Config.DisableSignup = false + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "workos@example.com", "John Doe", "test_prof_workos", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkosDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "workos@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkosDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "workos@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkosDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("test_prof_workos", "workos@example.com", "John Doe", "", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "workos@example.com", "John Doe", "test_prof_workos", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalWorkosSuccessWhenMatchingToken() { + ts.createUser("test_prof_workos", "workos@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "workos@example.com", "John Doe", "test_prof_workos", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalWorkosErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "workos", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalWorkosErrorWhenWrongToken() { + ts.createUser("test_prof_workos", "workos@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "workos", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalWorkosErrorWhenEmailDoesntMatch() { + ts.createUser("test_prof_workos", "workos@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/internal/api/external_zoom_test.go b/internal/api/external_zoom_test.go new file mode 100644 index 000000000..ea3f15c4d --- /dev/null +++ b/internal/api/external_zoom_test.go @@ -0,0 +1,167 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + zoomUser string = `{"id":"zoomUserId","first_name":"John","last_name": "Doe","email": "zoom@example.com","verified": 1,"pic_url":"http://example.com/avatar"}` + zoomUserWrongEmail string = `{"id":"zoomUserId","first_name":"John","last_name": "Doe","email": "other@example.com","verified": 1,"pic_url":"http://example.com/avatar"}` + zoomUserNoEmail string = `{"id":"zoomUserId","first_name":"John","last_name": "Doe","verified": 1,"pic_url":"http://example.com/avatar"}` +) + +func (ts *ExternalTestSuite) TestSignupExternalZoom() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=zoom", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Zoom.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Zoom.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("zoom", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func ZoomTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Zoom.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"zoom_token","expires_in":100000}`) + case "/v2/users/me": + *userCount++ + ts.Contains(r.Header, "Authorization") + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown zoom oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Zoom.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalZoomAuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "zoom@example.com", "John Doe", "zoomUserId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalZoomDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "zoom@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalZoomDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "zoom@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalZoomDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("zoomUserId", "zoom@example.com", "John Doe", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "zoom@example.com", "John Doe", "zoomUserId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalZoomSuccessWhenMatchingToken() { + ts.createUser("zoomUserId", "zoom@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "zoom@example.com", "John Doe", "zoomUserId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalZoomErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "zoom", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalZoomErrorWhenWrongToken() { + ts.createUser("zoomUserId", "zoom@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "zoom", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalZoomErrorWhenEmailDoesntMatch() { + ts.createUser("zoomUserId", "zoom@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/internal/api/helpers.go b/internal/api/helpers.go new file mode 100644 index 000000000..965e51f81 --- /dev/null +++ b/internal/api/helpers.go @@ -0,0 +1,107 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/security" + + "github.com/supabase/auth/internal/utilities" +) + +func sendJSON(w http.ResponseWriter, status int, obj interface{}) error { + w.Header().Set("Content-Type", "application/json") + b, err := json.Marshal(obj) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Error encoding json response: %v", obj)) + } + w.WriteHeader(status) + _, err = w.Write(b) + return err +} + +func isAdmin(u *models.User, config *conf.GlobalConfiguration) bool { + return config.JWT.Aud == u.Aud && u.HasRole(config.JWT.AdminGroupName) +} + +func (a *API) requestAud(ctx context.Context, r *http.Request) string { + config := a.config + // First check for an audience in the header + if aud := r.Header.Get(audHeaderName); aud != "" { + return aud + } + + // Then check the token + claims := getClaims(ctx) + + if claims != nil { + aud, _ := claims.GetAudience() + if len(aud) != 0 && aud[0] != "" { + return aud[0] + } + } + + // Finally, return the default if none of the above methods are successful + return config.JWT.Aud +} + +func isStringInSlice(checkValue string, list []string) bool { + for _, val := range list { + if val == checkValue { + return true + } + } + return false +} + +type RequestParams interface { + AdminUserParams | + CreateSSOProviderParams | + EnrollFactorParams | + GenerateLinkParams | + IdTokenGrantParams | + InviteParams | + OtpParams | + PKCEGrantParams | + PasswordGrantParams | + RecoverParams | + RefreshTokenGrantParams | + ResendConfirmationParams | + SignupParams | + SingleSignOnParams | + SmsParams | + Web3GrantParams | + UserUpdateParams | + VerifyFactorParams | + VerifyParams | + adminUserUpdateFactorParams | + adminUserDeleteParams | + security.GotrueRequest | + ChallengeFactorParams | + + struct { + Email string `json:"email"` + Phone string `json:"phone"` + } | + struct { + Email string `json:"email"` + } +} + +// retrieveRequestParams is a generic method that unmarshals the request body into the params struct provided +func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error { + body, err := utilities.GetBodyBytes(r) + if err != nil { + return internalServerError("Could not read body into byte slice").WithInternalError(err) + } + if err := json.Unmarshal(body, params); err != nil { + return badRequestError(apierrors.ErrorCodeBadJSON, "Could not parse request body as JSON: %v", err) + } + return nil +} diff --git a/internal/api/helpers_test.go b/internal/api/helpers_test.go new file mode 100644 index 000000000..b75cab833 --- /dev/null +++ b/internal/api/helpers_test.go @@ -0,0 +1,152 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" +) + +func TestIsValidCodeChallenge(t *testing.T) { + cases := []struct { + challenge string + isValid bool + expectedError error + }{ + { + challenge: "invalid", + isValid: false, + expectedError: badRequestError(apierrors.ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), + }, + { + challenge: "codechallengecontainsinvalidcharacterslike@$^&*", + isValid: false, + expectedError: badRequestError(apierrors.ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), + }, + { + challenge: "validchallengevalidchallengevalidchallengevalidchallenge", + isValid: true, + expectedError: nil, + }, + } + + for _, c := range cases { + valid, err := isValidCodeChallenge(c.challenge) + require.Equal(t, c.isValid, valid) + require.Equal(t, c.expectedError, err) + } +} + +func TestIsValidPKCEParams(t *testing.T) { + cases := []struct { + challengeMethod string + challenge string + expected error + }{ + { + challengeMethod: "", + challenge: "", + expected: nil, + }, + { + challengeMethod: "test", + challenge: "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttest", + expected: nil, + }, + { + challengeMethod: "test", + challenge: "", + expected: badRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), + }, + { + challengeMethod: "", + challenge: "test", + expected: badRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + err := validatePKCEParams(c.challengeMethod, c.challenge) + require.Equal(t, c.expected, err) + }) + } +} + +func TestRequestAud(ts *testing.T) { + mockAPI := API{ + config: &conf.GlobalConfiguration{ + JWT: conf.JWTConfiguration{ + Aud: "authenticated", + Secret: "test-secret", + }, + }, + } + + cases := []struct { + desc string + headers map[string]string + payload map[string]interface{} + expectedAud string + }{ + { + desc: "Valid audience slice", + headers: map[string]string{ + audHeaderName: "my_custom_aud", + }, + payload: map[string]interface{}{ + "aud": "authenticated", + }, + expectedAud: "my_custom_aud", + }, + { + desc: "Valid custom audience", + payload: map[string]interface{}{ + "aud": "my_custom_aud", + }, + expectedAud: "my_custom_aud", + }, + { + desc: "Invalid audience", + payload: map[string]interface{}{ + "aud": "", + }, + expectedAud: mockAPI.config.JWT.Aud, + }, + { + desc: "Missing audience", + payload: map[string]interface{}{ + "sub": "d6044b6e-b0ec-4efe-a055-0d2d6ff1dbd8", + }, + expectedAud: mockAPI.config.JWT.Aud, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func(t *testing.T) { + claims := jwt.MapClaims(c.payload) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := token.SignedString([]byte(mockAPI.config.JWT.Secret)) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer: %s", signed)) + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // set the token in the request context for requestAud + ctx, err := mockAPI.parseJWTClaims(signed, req) + require.NoError(t, err) + aud := mockAPI.requestAud(ctx, req) + require.Equal(t, c.expectedAud, aud) + }) + } + +} diff --git a/internal/api/hooks.go b/internal/api/hooks.go new file mode 100644 index 000000000..57d21477a --- /dev/null +++ b/internal/api/hooks.go @@ -0,0 +1,406 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "net" + "net/http" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/sirupsen/logrus" + standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +const ( + DefaultHTTPHookTimeout = 5 * time.Second + DefaultHTTPHookRetries = 3 + HTTPHookBackoffDuration = 2 * time.Second + PayloadLimit = 200 * 1024 // 200KB +) + +func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { + db := a.db.WithContext(ctx) + + request, err := json.Marshal(input) + if err != nil { + panic(err) + } + + var response []byte + invokeHookFunc := func(tx *storage.Connection) error { + // We rely on Postgres timeouts to ensure the function doesn't overrun + if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil { + return terr + } + + if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", hookConfig.HookName), request).First(&response); terr != nil { + return terr + } + + // reset the timeout + if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil { + return terr + } + + return nil + } + + if tx != nil { + if err := invokeHookFunc(tx); err != nil { + return nil, err + } + } else { + if err := db.Transaction(invokeHookFunc); err != nil { + return nil, err + } + } + + if err := json.Unmarshal(response, output); err != nil { + return response, err + } + + return response, nil +} + +func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input any) ([]byte, error) { + ctx := r.Context() + client := http.Client{ + Timeout: DefaultHTTPHookTimeout, + } + ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) + defer cancel() + + log := observability.GetLogEntry(r).Entry + requestURL := hookConfig.URI + hookLog := log.WithFields(logrus.Fields{ + "component": "auth_hook", + "url": requestURL, + }) + + inputPayload, err := json.Marshal(input) + if err != nil { + return nil, err + } + for i := 0; i < DefaultHTTPHookRetries; i++ { + if i == 0 { + hookLog.Debugf("invocation attempt: %d", i) + } else { + hookLog.Infof("invocation attempt: %d", i) + } + msgID := uuid.Must(uuid.NewV4()) + currentTime := time.Now() + signatureList, err := generateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) + if err != nil { + panic("Failed to make request object") + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("webhook-id", msgID.String()) + req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) + req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) + // By default, Go Client sets encoding to gzip, which does not carry a content length header. + req.Header.Set("Accept-Encoding", "identity") + + rsp, err := client.Do(req) + if err != nil && errors.Is(err, context.DeadlineExceeded) { + return nil, unprocessableEntityError(apierrors.ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds())) + + } else if err != nil { + if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 { + hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) + time.Sleep(HTTPHookBackoffDuration) + continue + } else if i == DefaultHTTPHookRetries-1 { + return nil, unprocessableEntityError(apierrors.ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries") + } else { + return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) + } + } + + defer rsp.Body.Close() + + switch rsp.StatusCode { + case http.StatusOK, http.StatusNoContent, http.StatusAccepted: + // Header.Get is case insensitive + contentType := rsp.Header.Get("Content-Type") + if contentType == "" { + return nil, badRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, "Invalid Content-Type: Missing Content-Type header") + } + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return nil, badRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, fmt.Sprintf("Invalid Content-Type header: %s", err.Error())) + } + if mediaType != "application/json" { + return nil, badRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, "Invalid JSON response. Received content-type: "+contentType) + } + if rsp.Body == nil { + return nil, nil + } + limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} + body, err := io.ReadAll(&limitedReader) + if err != nil { + return nil, err + } + if limitedReader.N <= 0 { + // check if the response body still has excess bytes to be read + if n, _ := rsp.Body.Read(make([]byte, 1)); n > 0 { + return nil, unprocessableEntityError(apierrors.ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size exceeded size limit of %d bytes", PayloadLimit)) + } + } + return body, nil + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + retryAfterHeader := rsp.Header.Get("retry-after") + // Check for truthy values to allow for flexibility to switch to time duration + if retryAfterHeader != "" { + continue + } + return nil, internalServerError("Service currently unavailable due to hook") + case http.StatusBadRequest: + return nil, internalServerError("Invalid payload sent to hook") + case http.StatusUnauthorized: + return nil, internalServerError("Hook requires authorization token") + default: + return nil, internalServerError("Unexpected status code returned from hook: %d", rsp.StatusCode) + } + } + return nil, nil +} + +// invokePostgresHook invokes the hook code. conn can be nil, in which case a new +// transaction is opened. If calling invokeHook within a transaction, always +// pass the current transaction, as pool-exhaustion deadlocks are very easy to +// trigger. +func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, output any) error { + var err error + var response []byte + + switch input.(type) { + case *hooks.SendSMSInput: + hookOutput, ok := output.(*hooks.SendSMSOutput) + if !ok { + panic("output should be *hooks.SendSMSOutput") + } + if response, err = a.runHook(r, conn, a.config.Hook.SendSMS, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling Send SMS output.").WithInternalError(err) + } + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + return httpError.WithInternalError(&hookOutput.HookError) + } + return nil + case *hooks.SendEmailInput: + hookOutput, ok := output.(*hooks.SendEmailOutput) + if !ok { + panic("output should be *hooks.SendEmailOutput") + } + if response, err = a.runHook(r, conn, a.config.Hook.SendEmail, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling Send Email output.").WithInternalError(err) + } + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + return nil + case *hooks.MFAVerificationAttemptInput: + hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput) + if !ok { + panic("output should be *hooks.MFAVerificationAttemptOutput") + } + if response, err = a.runHook(r, conn, a.config.Hook.MFAVerificationAttempt, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling MFA Verification Attempt output.").WithInternalError(err) + } + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + return nil + case *hooks.PasswordVerificationAttemptInput: + hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput) + if !ok { + panic("output should be *hooks.PasswordVerificationAttemptOutput") + } + + if response, err = a.runHook(r, conn, a.config.Hook.PasswordVerificationAttempt, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling Password Verification Attempt output.").WithInternalError(err) + } + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + + return nil + case *hooks.CustomAccessTokenInput: + hookOutput, ok := output.(*hooks.CustomAccessTokenOutput) + if !ok { + panic("output should be *hooks.CustomAccessTokenOutput") + } + if response, err = a.runHook(r, conn, a.config.Hook.CustomAccessToken, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling Custom Access Token output.").WithInternalError(err) + } + + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + if err := validateTokenClaims(hookOutput.Claims); err != nil { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: err.Error(), + } + + return httpError + } + return nil + } + return nil +} + +func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { + ctx := r.Context() + + logEntry := observability.GetLogEntry(r) + hookStart := time.Now() + + var response []byte + var err error + + switch { + case strings.HasPrefix(hookConfig.URI, "http:") || strings.HasPrefix(hookConfig.URI, "https:"): + response, err = a.runHTTPHook(r, hookConfig, input) + case strings.HasPrefix(hookConfig.URI, "pg-functions:"): + response, err = a.runPostgresHook(ctx, conn, hookConfig, input, output) + default: + return nil, fmt.Errorf("unsupported protocol: %q only postgres hooks and HTTPS functions are supported at the moment", hookConfig.URI) + } + + duration := time.Since(hookStart) + + if err != nil { + logEntry.Entry.WithFields(logrus.Fields{ + "action": "run_hook", + "hook": hookConfig.URI, + "success": false, + "duration": duration.Microseconds(), + }).WithError(err).Warn("Hook errored out") + + return nil, internalServerError("Error running hook URI: %v", hookConfig.URI).WithInternalError(err) + } + + logEntry.Entry.WithFields(logrus.Fields{ + "action": "run_hook", + "hook": hookConfig.URI, + "success": true, + "duration": duration.Microseconds(), + }).WithError(err).Info("Hook ran successfully") + + return response, nil +} + +func generateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) { + SymmetricSignaturePrefix := "v1," + // TODO(joel): Handle asymmetric case once library has been upgraded + var signatureList []string + for _, secret := range secrets { + if strings.HasPrefix(secret, SymmetricSignaturePrefix) { + trimmedSecret := strings.TrimPrefix(secret, SymmetricSignaturePrefix) + wh, err := standardwebhooks.NewWebhook(trimmedSecret) + if err != nil { + return nil, err + } + signature, err := wh.Sign(msgID.String(), currentTime, inputPayload) + if err != nil { + return nil, err + } + signatureList = append(signatureList, signature) + } else { + return nil, errors.New("invalid signature format") + } + } + return signatureList, nil +} diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go new file mode 100644 index 000000000..c78ce5f2f --- /dev/null +++ b/internal/api/hooks_test.go @@ -0,0 +1,287 @@ +package api + +import ( + "encoding/json" + "net/http" + "testing" + + "net/http/httptest" + + "github.com/pkg/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + + "gopkg.in/h2non/gock.v1" +) + +var handleApiRequest func(*http.Request) (*http.Response, error) + +type HooksTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + TestUser *models.User +} + +type MockHttpClient struct { + mock.Mock +} + +func (m *MockHttpClient) Do(req *http.Request) (*http.Response, error) { + return handleApiRequest(req) +} + +func TestHooks(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &HooksTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *HooksTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + u, err := models.NewUser("123456789", "testemail@gmail.com", "securetestpassword", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + ts.TestUser = u +} + +func (ts *HooksTestSuite) TestRunHTTPHook() { + // setup mock requests for hooks + defer gock.OffAll() + + input := hooks.SendSMSInput{ + User: ts.TestUser, + SMS: hooks.SMS{ + OTP: "123456", + }, + } + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.SendSMS.URI = testURL + + unsuccessfulResponse := hooks.AuthHookError{ + HTTPCode: http.StatusUnprocessableEntity, + Message: "test error", + } + + testCases := []struct { + description string + expectError bool + mockResponse hooks.AuthHookError + }{ + { + description: "Hook returns success", + expectError: false, + mockResponse: hooks.AuthHookError{}, + }, + { + description: "Hook returns error", + expectError: true, + mockResponse: unsuccessfulResponse, + }, + } + + gock.New(ts.Config.Hook.SendSMS.URI). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(hooks.SendSMSOutput{}) + + gock.New(ts.Config.Hook.SendSMS.URI). + Post("/"). + MatchType("json"). + Reply(http.StatusUnprocessableEntity). + JSON(hooks.SendSMSOutput{HookError: unsuccessfulResponse}) + + for _, tc := range testCases { + ts.Run(tc.description, func() { + req, _ := http.NewRequest("POST", ts.Config.Hook.SendSMS.URI, nil) + body, err := ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + + if !tc.expectError { + require.NoError(ts.T(), err) + } else { + require.Error(ts.T(), err) + if body != nil { + var output hooks.SendSMSOutput + require.NoError(ts.T(), json.Unmarshal(body, &output)) + require.Equal(ts.T(), unsuccessfulResponse.HTTPCode, output.HookError.HTTPCode) + require.Equal(ts.T(), unsuccessfulResponse.Message, output.HookError.Message) + } + } + }) + } + require.True(ts.T(), gock.IsDone()) +} + +func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { + defer gock.OffAll() + + input := hooks.SendSMSInput{ + User: ts.TestUser, + SMS: hooks.SMS{ + OTP: "123456", + }, + } + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.SendSMS.URI = testURL + + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusTooManyRequests). + SetHeader("retry-after", "true").SetHeader("content-type", "application/json") + + // Simulate an additional response for the retry attempt + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(hooks.SendSMSOutput{}).SetHeader("content-type", "application/json") + + // Simulate the original HTTP request which triggered the hook + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) + + body, err := ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + require.NoError(ts.T(), err) + + var output hooks.SendSMSOutput + err = json.Unmarshal(body, &output) + require.NoError(ts.T(), err, "Unmarshal should not fail") + + // Ensure that all expected HTTP interactions (mocks) have been called + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry") +} + +func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() { + defer gock.OffAll() + + input := hooks.SendSMSInput{ + User: ts.TestUser, + SMS: hooks.SMS{ + OTP: "123456", + }, + } + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.SendSMS.URI = testURL + + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + SetHeader("content-type", "text/plain") + + req, err := http.NewRequest("POST", "http://localhost:9999/otp", nil) + require.NoError(ts.T(), err) + + _, err = ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + require.Error(ts.T(), err, "Expected an error due to wrong content type") + require.Contains(ts.T(), err.Error(), "Invalid JSON response.") + + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called") +} + +func (ts *HooksTestSuite) TestInvokeHookIntegration() { + // We use the Send Email Hook as illustration + defer gock.OffAll() + hookFunctionSQL := ` + create or replace function invoke_test(input jsonb) + returns json as $$ + begin + return input; + end; $$ language plpgsql;` + require.NoError(ts.T(), ts.API.db.RawQuery(hookFunctionSQL).Exec()) + + testHTTPUri := "http://myauthservice.com/signup" + testHTTPSUri := "https://myauthservice.com/signup" + testPGUri := "pg-functions://postgres/auth/invoke_test" + successOutput := map[string]interface{}{} + authEndpoint := "https://app.myapp.com/otp" + gock.New(testHTTPUri). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(successOutput).SetHeader("content-type", "application/json") + + gock.New(testHTTPSUri). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(successOutput).SetHeader("content-type", "application/json") + + tests := []struct { + description string + conn *storage.Connection + request *http.Request + input any + output any + uri string + expectedError error + }{ + { + description: "HTTP endpoint success", + conn: nil, + request: httptest.NewRequest("POST", authEndpoint, nil), + input: &hooks.SendEmailInput{}, + output: &hooks.SendEmailOutput{}, + uri: testHTTPUri, + }, + { + description: "HTTPS endpoint success", + conn: nil, + request: httptest.NewRequest("POST", authEndpoint, nil), + input: &hooks.SendEmailInput{}, + output: &hooks.SendEmailOutput{}, + uri: testHTTPSUri, + }, + { + description: "PostgreSQL function success", + conn: ts.API.db, + request: httptest.NewRequest("POST", authEndpoint, nil), + input: &hooks.SendEmailInput{}, + output: &hooks.SendEmailOutput{}, + uri: testPGUri, + }, + { + description: "Unsupported protocol error", + conn: nil, + request: httptest.NewRequest("POST", authEndpoint, nil), + input: &hooks.SendEmailInput{}, + output: &hooks.SendEmailOutput{}, + uri: "ftp://example.com/path", + expectedError: errors.New("unsupported protocol: \"ftp://example.com/path\" only postgres hooks and HTTPS functions are supported at the moment"), + }, + } + + var err error + for _, tc := range tests { + // Set up hook config + ts.Config.Hook.SendEmail.Enabled = true + ts.Config.Hook.SendEmail.URI = tc.uri + require.NoError(ts.T(), ts.Config.Hook.SendEmail.PopulateExtensibilityPoint()) + + ts.Run(tc.description, func() { + err = ts.API.invokeHook(tc.conn, tc.request, tc.input, tc.output) + if tc.expectedError != nil { + require.EqualError(ts.T(), err, tc.expectedError.Error()) + } else { + require.NoError(ts.T(), err) + } + }) + + } + // Ensure that all expected HTTP interactions (mocks) have been called + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry") +} diff --git a/internal/api/identity.go b/internal/api/identity.go new file mode 100644 index 000000000..a2664d72f --- /dev/null +++ b/internal/api/identity.go @@ -0,0 +1,156 @@ +package api + +import ( + "context" + "net/http" + + "github.com/fatih/structs" + "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + claims := getClaims(ctx) + if claims == nil { + return internalServerError("Could not read claims") + } + + identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) + if err != nil { + return notFoundError(apierrors.ErrorCodeValidationFailed, "identity_id must be an UUID") + } + + aud := a.requestAud(ctx, r) + audienceFromClaims, _ := claims.GetAudience() + if len(audienceFromClaims) == 0 || aud != audienceFromClaims[0] { + return forbiddenError(apierrors.ErrorCodeUnexpectedAudience, "Token audience doesn't match request audience") + } + + user := getUser(ctx) + if len(user.Identities) <= 1 { + return unprocessableEntityError(apierrors.ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking") + } + var identityToBeDeleted *models.Identity + for i := range user.Identities { + identity := user.Identities[i] + if identity.ID == identityID { + identityToBeDeleted = &identity + break + } + } + if identityToBeDeleted == nil { + return unprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist") + } + + err = a.db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.IdentityUnlinkAction, "", map[string]interface{}{ + "identity_id": identityToBeDeleted.ID, + "provider": identityToBeDeleted.Provider, + "provider_id": identityToBeDeleted.ProviderID, + }); terr != nil { + return internalServerError("Error recording audit log entry").WithInternalError(terr) + } + if terr := tx.Destroy(identityToBeDeleted); terr != nil { + return internalServerError("Database error deleting identity").WithInternalError(terr) + } + + switch identityToBeDeleted.Provider { + case "phone": + user.PhoneConfirmedAt = nil + if terr := user.SetPhone(tx, ""); terr != nil { + return internalServerError("Database error updating user phone").WithInternalError(terr) + } + if terr := tx.UpdateOnly(user, "phone_confirmed_at"); terr != nil { + return internalServerError("Database error updating user phone").WithInternalError(terr) + } + default: + if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil { + if models.IsUniqueConstraintViolatedError(terr) { + return unprocessableEntityError(apierrors.ErrorCodeEmailConflictIdentityNotDeletable, "Unable to unlink identity due to email conflict").WithInternalError(terr) + } + return internalServerError("Database error updating user email").WithInternalError(terr) + } + } + if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { + return internalServerError("Database error updating user providers").WithInternalError(terr) + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{}) +} + +func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + rurl, err := a.GetExternalProviderRedirectURL(w, r, user) + if err != nil { + return err + } + skipHTTPRedirect := r.URL.Query().Get("skip_http_redirect") == "true" + if skipHTTPRedirect { + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "url": rurl, + }) + } + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, providerType string) (*models.User, error) { + targetUser := getTargetUser(ctx) + identity, terr := models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType) + if terr != nil { + if !models.IsNotFoundError(terr) { + return nil, internalServerError("Database error finding identity for linking").WithInternalError(terr) + } + } + if identity != nil { + if identity.UserID == targetUser.ID { + return nil, unprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked") + } + return nil, unprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked to another user") + } + if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { + return nil, terr + } + + if targetUser.GetEmail() == "" { + if terr := targetUser.UpdateUserEmailFromIdentities(tx); terr != nil { + if models.IsUniqueConstraintViolatedError(terr) { + return nil, badRequestError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + } + return nil, terr + } + if !userData.Metadata.EmailVerified { + if terr := a.sendConfirmation(r, tx, targetUser, models.ImplicitFlow); terr != nil { + return nil, terr + } + return nil, storage.NewCommitWithError(unprocessableEntityError(apierrors.ErrorCodeEmailNotConfirmed, "Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)) + } + if terr := targetUser.Confirm(tx); terr != nil { + return nil, terr + } + + if targetUser.IsAnonymous { + targetUser.IsAnonymous = false + if terr := tx.UpdateOnly(targetUser, "is_anonymous"); terr != nil { + return nil, terr + } + } + } + + if terr := targetUser.UpdateAppMetaDataProviders(tx); terr != nil { + return nil, terr + } + return targetUser, nil +} diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go new file mode 100644 index 000000000..b04a5d74f --- /dev/null +++ b/internal/api/identity_test.go @@ -0,0 +1,228 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type IdentityTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestIdentity(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + ts := &IdentityTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + suite.Run(t, ts) +} + +func (ts *IdentityTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("", "one@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), u.Confirm(ts.API.db)) + + // Create identity + i, err := models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": u.GetEmail(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i)) + + // Create user with 2 identities + u, err = models.NewUser("123456789", "two@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), u.Confirm(ts.API.db)) + require.NoError(ts.T(), u.ConfirmPhone(ts.API.db)) + + i, err = models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": u.GetEmail(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i)) + + i2, err := models.NewIdentity(u, "phone", map[string]interface{}{ + "sub": u.ID.String(), + "phone": u.GetPhone(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i2)) +} + +func (ts *IdentityTestSuite) TestLinkIdentityToUser() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + ctx := withTargetUser(context.Background(), u) + + // link a valid identity + testValidUserData := &provider.UserProvidedData{ + Metadata: &provider.Claims{ + Subject: "test_subject", + }, + } + // request is just used as a placeholder in the function + r := httptest.NewRequest(http.MethodGet, "/identities", nil) + u, err = ts.API.linkIdentityToUser(r, ctx, ts.API.db, testValidUserData, "test") + require.NoError(ts.T(), err) + + // load associated identities for the user + ts.API.db.Load(u, "Identities") + require.Len(ts.T(), u.Identities, 2) + require.Equal(ts.T(), u.AppMetaData["provider"], "email") + require.Equal(ts.T(), u.AppMetaData["providers"], []string{"email", "test"}) + + // link an already existing identity + testExistingUserData := &provider.UserProvidedData{ + Metadata: &provider.Claims{ + Subject: u.ID.String(), + }, + } + u, err = ts.API.linkIdentityToUser(r, ctx, ts.API.db, testExistingUserData, "email") + require.ErrorIs(ts.T(), err, unprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked")) + require.Nil(ts.T(), u) +} + +func (ts *IdentityTestSuite) TestUnlinkIdentityError() { + ts.Config.Security.ManualLinkingEnabled = true + userWithOneIdentity, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + userWithTwoIdentities, err := models.FindUserByEmailAndAudience(ts.API.db, "two@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + cases := []struct { + desc string + user *models.User + identityId uuid.UUID + expectedError *HTTPError + }{ + { + desc: "User must have at least 1 identity after unlinking", + user: userWithOneIdentity, + identityId: userWithOneIdentity.Identities[0].ID, + expectedError: unprocessableEntityError(apierrors.ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking"), + }, + { + desc: "Identity doesn't exist", + user: userWithTwoIdentities, + identityId: uuid.Must(uuid.NewV4()), + expectedError: unprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist"), + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + token := ts.generateAccessTokenAndSession(c.user) + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", c.identityId), nil) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expectedError.HTTPStatus, w.Code) + + var data HTTPError + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), c.expectedError.Message, data.Message) + }) + } +} + +func (ts *IdentityTestSuite) TestUnlinkIdentity() { + ts.Config.Security.ManualLinkingEnabled = true + + // we want to test 2 cases here: unlinking a phone identity and email identity from a user + cases := []struct { + desc string + // the provider to be unlinked + provider string + // the remaining provider that should be linked to the user + providerRemaining string + }{ + { + desc: "Unlink phone identity successfully", + provider: "phone", + providerRemaining: "email", + }, + { + desc: "Unlink email identity successfully", + provider: "email", + providerRemaining: "phone", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // teardown and reset the state of the db to prevent running into errors + ts.SetupTest() + u, err := models.FindUserByEmailAndAudience(ts.API.db, "two@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + identity, err := models.FindIdentityByIdAndProvider(ts.API.db, u.ID.String(), c.provider) + require.NoError(ts.T(), err) + + token := ts.generateAccessTokenAndSession(u) + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", identity.ID), nil) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // sanity checks + u, err = models.FindUserByID(ts.API.db, u.ID) + require.NoError(ts.T(), err) + require.Len(ts.T(), u.Identities, 1) + require.Equal(ts.T(), u.Identities[0].Provider, c.providerRemaining) + + // conditional checks depending on the provider that was unlinked + switch c.provider { + case "phone": + require.Equal(ts.T(), "", u.GetPhone()) + require.Nil(ts.T(), u.PhoneConfirmedAt) + case "email": + require.Equal(ts.T(), "", u.GetEmail()) + require.Nil(ts.T(), u.EmailConfirmedAt) + } + + // user still has a phone / email identity linked so it should not be unconfirmed + require.NotNil(ts.T(), u.ConfirmedAt) + }) + } + +} + +func (ts *IdentityTestSuite) generateAccessTokenAndSession(u *models.User) string { + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + require.NoError(ts.T(), err) + return token + +} diff --git a/internal/api/invite.go b/internal/api/invite.go new file mode 100644 index 000000000..de35942b1 --- /dev/null +++ b/internal/api/invite.go @@ -0,0 +1,93 @@ +package api + +import ( + "net/http" + + "github.com/fatih/structs" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// InviteParams are the parameters the Signup endpoint accepts +type InviteParams struct { + Email string `json:"email"` + Data map[string]interface{} `json:"data"` +} + +// Invite is the endpoint for inviting a new user +func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + adminUser := getAdminUser(ctx) + params := &InviteParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + var err error + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + + aud := a.requestAud(ctx, r) + user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) + if err != nil && !models.IsNotFoundError(err) { + return internalServerError("Database error finding user").WithInternalError(err) + } + + err = db.Transaction(func(tx *storage.Connection) error { + if user != nil { + if user.IsConfirmed() { + return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + } + } else { + signupParams := SignupParams{ + Email: params.Email, + Data: params.Data, + Aud: aud, + Provider: "email", + } + + // because params above sets no password, this method + // is not computationally hard so it can be used within + // a database transaction + user, err = signupParams.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + + user, err = a.signupNewUser(tx, user) + if err != nil { + return err + } + identity, err := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + if err != nil { + return err + } + user.Identities = []models.Identity{*identity} + } + + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserInvitedAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + }); terr != nil { + return terr + } + + if err := a.sendInvite(r, tx, user); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, user) +} diff --git a/api/invite_test.go b/internal/api/invite_test.go similarity index 74% rename from api/invite_test.go rename to internal/api/invite_test.go index 8d105213c..ff0baca6a 100644 --- a/api/invite_test.go +++ b/internal/api/invite_test.go @@ -10,32 +10,30 @@ import ( "testing" "time" - "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" + jwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" ) type InviteTestSuite struct { suite.Suite API *API - Config *conf.Configuration + Config *conf.GlobalConfiguration - token string - instanceID uuid.UUID + token string } func TestInvite(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() + api, config, err := setupAPIForTest() require.NoError(t, err) ts := &InviteTestSuite{ - API: api, - Config: config, - instanceID: instanceID, + API: api, + Config: config, } defer api.db.Close() @@ -51,19 +49,28 @@ func (ts *InviteTestSuite) SetupTest() { func (ts *InviteTestSuite) makeSuperAdmin(email string) string { // Cleanup existing user, if they already exist - if u, _ := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, email, ts.Config.JWT.Aud); u != nil { + if u, _ := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud); u != nil { require.NoError(ts.T(), ts.API.db.Destroy(u), "Error deleting user") } - u, err := models.NewUser(ts.instanceID, email, "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) + u, err := models.NewUser("123456789", email, "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u)) u.Role = "supabase_admin" - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + var token string + + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + req := httptest.NewRequest(http.MethodPost, "/invite", nil) + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.Invite) + require.NoError(ts.T(), err, "Error generating access token") - p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) _, err = p.Parse(token, func(token *jwt.Token) (interface{}, error) { return []byte(ts.Config.JWT.Secret), nil }) @@ -94,6 +101,55 @@ func (ts *InviteTestSuite) TestInvite() { assert.Equal(ts.T(), http.StatusOK, w.Code) } +func (ts *InviteTestSuite) TestInviteAfterSignupShouldNotReturnSensitiveFields() { + // To allow us to send signup and invite request in succession + ts.Config.SMTP.MaxFrequency = 5 + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "data": map[string]interface{}{ + "a": 1, + }, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/invite", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "test123", + "data": map[string]interface{}{ + "a": 1, + }, + })) + + // Setup request + req = httptest.NewRequest(http.MethodPost, "/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + x := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(x, req) + + require.Equal(ts.T(), http.StatusOK, x.Code) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(x.Body).Decode(&data)) + // Sensitive fields + require.Equal(ts.T(), 0, len(data.Identities)) + require.Equal(ts.T(), 0, len(data.UserMetaData)) +} + func (ts *InviteTestSuite) TestInvite_WithoutAccess() { // Request body var buffer bytes.Buffer @@ -112,7 +168,7 @@ func (ts *InviteTestSuite) TestInvite_WithoutAccess() { w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) + assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) // 401 OK because the invite request above has no Authorization header } func (ts *InviteTestSuite) TestVerifyInvite() { @@ -126,6 +182,7 @@ func (ts *InviteTestSuite) TestVerifyInvite() { "Verify invite with password", "test@example.com", map[string]interface{}{ + "email": "test@example.com", "type": "invite", "token": "asdf", "password": "testing", @@ -136,6 +193,7 @@ func (ts *InviteTestSuite) TestVerifyInvite() { "Verify invite with no password", "test1@example.com", map[string]interface{}{ + "email": "test1@example.com", "type": "invite", "token": "asdf", }, @@ -145,17 +203,18 @@ func (ts *InviteTestSuite) TestVerifyInvite() { for _, c := range cases { ts.Run(c.desc, func() { - user, err := models.NewUser(ts.instanceID, c.email, "", ts.Config.JWT.Aud, nil) + user, err := models.NewUser("", c.email, "", ts.Config.JWT.Aud, nil) now := time.Now() user.InvitedAt = &now user.ConfirmationSentAt = &now - user.EncryptedPassword = "" - user.ConfirmationToken = c.requestBody["token"].(string) + user.EncryptedPassword = nil + user.ConfirmationToken = crypto.GenerateTokenHash(c.email, c.requestBody["token"].(string)) require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(user)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken)) // Find test user - _, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, c.email, ts.Config.JWT.Aud) + _, err = models.FindUserByEmailAndAudience(ts.API.db, c.email, ts.Config.JWT.Aud) require.NoError(ts.T(), err) // Request body @@ -218,7 +277,7 @@ func (ts *InviteTestSuite) TestInviteExternalGitlab() { ts.Require().Equal(http.StatusOK, w.Code) // Find test user - user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "gitlab@example.com", ts.Config.JWT.Aud) + user, err := models.FindUserByEmailAndAudience(ts.API.db, "gitlab@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) // get redirect url w/ state @@ -260,7 +319,7 @@ func (ts *InviteTestSuite) TestInviteExternalGitlab() { ts.Equal(1, userCount) // ensure user has been created with metadata - user, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "gitlab@example.com", ts.Config.JWT.Aud) + user, err = models.FindUserByEmailAndAudience(ts.API.db, "gitlab@example.com", ts.Config.JWT.Aud) ts.Require().NoError(err) ts.Equal("Gitlab Test", user.UserMetaData["full_name"]) ts.Equal("http://example.com/avatar", user.UserMetaData["avatar_url"]) @@ -310,7 +369,7 @@ func (ts *InviteTestSuite) TestInviteExternalGitlab_MismatchedEmails() { ts.Require().Equal(http.StatusOK, w.Code) // Find test user - user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "gitlab@example.com", ts.Config.JWT.Aud) + user, err := models.FindUserByEmailAndAudience(ts.API.db, "gitlab@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) // get redirect url w/ state diff --git a/internal/api/jwks.go b/internal/api/jwks.go new file mode 100644 index 000000000..b8304d2dd --- /dev/null +++ b/internal/api/jwks.go @@ -0,0 +1,61 @@ +package api + +import ( + "net/http" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jwa" + jwk "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/supabase/auth/internal/conf" +) + +type JwksResponse struct { + Keys []jwk.Key `json:"keys"` +} + +func (a *API) Jwks(w http.ResponseWriter, r *http.Request) error { + config := a.config + resp := JwksResponse{ + Keys: []jwk.Key{}, + } + + for _, key := range config.JWT.Keys { + // don't expose hmac jwk in endpoint + if key.PublicKey == nil || key.PublicKey.KeyType() == jwa.OctetSeq { + continue + } + resp.Keys = append(resp.Keys, key.PublicKey) + } + + w.Header().Set("Cache-Control", "public, max-age=600") + return sendJSON(w, http.StatusOK, resp) +} + +func signJwt(config *conf.JWTConfiguration, claims jwt.Claims) (string, error) { + signingJwk, err := conf.GetSigningJwk(config) + if err != nil { + return "", err + } + signingMethod := conf.GetSigningAlg(signingJwk) + token := jwt.NewWithClaims(signingMethod, claims) + if token.Header == nil { + token.Header = make(map[string]interface{}) + } + + if _, ok := token.Header["kid"]; !ok { + if kid := signingJwk.KeyID(); kid != "" { + token.Header["kid"] = kid + } + } + // this serializes the aud claim to a string + jwt.MarshalSingleStringAsArray = false + signingKey, err := conf.GetSigningKey(signingJwk) + if err != nil { + return "", err + } + signed, err := token.SignedString(signingKey) + if err != nil { + return "", err + } + return signed, nil +} diff --git a/internal/api/jwks_test.go b/internal/api/jwks_test.go new file mode 100644 index 000000000..786d3438f --- /dev/null +++ b/internal/api/jwks_test.go @@ -0,0 +1,79 @@ +package api + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestJwks(t *testing.T) { + // generate RSA key pair for testing + rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + rsaJwkPrivate, err := jwk.FromRaw(rsaPrivateKey) + require.NoError(t, err) + rsaJwkPublic, err := rsaJwkPrivate.PublicKey() + require.NoError(t, err) + kid := rsaJwkPublic.KeyID() + + cases := []struct { + desc string + config conf.JWTConfiguration + expectedLen int + }{ + { + desc: "hmac key should not be returned", + config: conf.JWTConfiguration{ + Aud: "authenticated", + Secret: "test-secret", + }, + expectedLen: 0, + }, + { + desc: "rsa public key returned", + config: conf.JWTConfiguration{ + Aud: "authenticated", + Secret: "test-secret", + Keys: conf.JwtKeysDecoder{ + kid: conf.JwkInfo{ + PublicKey: rsaJwkPublic, + PrivateKey: rsaJwkPrivate, + }, + }, + }, + expectedLen: 1, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + mockAPI, _, err := setupAPIForTest() + require.NoError(t, err) + mockAPI.config.JWT = c.config + + req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil) + w := httptest.NewRecorder() + mockAPI.handler.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var data map[string]interface{} + require.NoError(t, json.NewDecoder(w.Body).Decode(&data)) + require.Len(t, data["keys"], c.expectedLen) + + for _, key := range data["keys"].([]interface{}) { + bytes, err := json.Marshal(key) + require.NoError(t, err) + actualKey, err := jwk.ParseKey(bytes) + require.NoError(t, err) + require.Equal(t, c.config.Keys[kid].PublicKey, actualKey) + } + }) + } +} diff --git a/internal/api/logout.go b/internal/api/logout.go new file mode 100644 index 000000000..1ab398e22 --- /dev/null +++ b/internal/api/logout.go @@ -0,0 +1,74 @@ +package api + +import ( + "fmt" + "net/http" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +type LogoutBehavior string + +const ( + LogoutGlobal LogoutBehavior = "global" + LogoutLocal LogoutBehavior = "local" + LogoutOthers LogoutBehavior = "others" +) + +// Logout is the endpoint for logging out a user and thereby revoking any refresh tokens +func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + scope := LogoutGlobal + + if r.URL.Query() != nil { + switch r.URL.Query().Get("scope") { + case "", "global": + scope = LogoutGlobal + + case "local": + scope = LogoutLocal + + case "others": + scope = LogoutOthers + + default: + return badRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) + } + } + + s := getSession(ctx) + u := getUser(ctx) + + err := db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, u, models.LogoutAction, "", nil); terr != nil { + return terr + } + + if s == nil { + logrus.Infof("user has an empty session_id claim: %s", u.ID) + } else { + //exhaustive:ignore Default case is handled below. + switch scope { + case LogoutLocal: + return models.LogoutSession(tx, s.ID) + + case LogoutOthers: + return models.LogoutAllExceptMe(tx, s.ID, u.ID) + } + } + + // default mode, log out everywhere + return models.Logout(tx, u.ID) + }) + if err != nil { + return internalServerError("Error logging out user").WithInternalError(err) + } + + w.WriteHeader(http.StatusNoContent) + + return nil +} diff --git a/internal/api/logout_test.go b/internal/api/logout_test.go new file mode 100644 index 000000000..b1a0fdbb6 --- /dev/null +++ b/internal/api/logout_test.go @@ -0,0 +1,75 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type LogoutTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + token string +} + +func TestLogout(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &LogoutTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *LogoutTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + + // generate access token to use for logout + var t string + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + t, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + require.NoError(ts.T(), err) + ts.token = t +} + +func (ts *LogoutTestSuite) TestLogoutSuccess() { + for _, scope := range []string{"", "global", "local", "others"} { + ts.SetupTest() + + reqURL, err := url.ParseRequestURI("http://localhost/logout") + require.NoError(ts.T(), err) + + if scope != "" { + query := reqURL.Query() + query.Set("scope", scope) + reqURL.RawQuery = query.Encode() + } + + req := httptest.NewRequest(http.MethodPost, reqURL.String(), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + } +} diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go new file mode 100644 index 000000000..fa0792e02 --- /dev/null +++ b/internal/api/magic_link.go @@ -0,0 +1,165 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// MagicLinkParams holds the parameters for a magic link request +type MagicLinkParams struct { + Email string `json:"email"` + Data map[string]interface{} `json:"data"` + CodeChallengeMethod string `json:"code_challenge_method"` + CodeChallenge string `json:"code_challenge"` +} + +func (p *MagicLinkParams) Validate(a *API) error { + if p.Email == "" { + return unprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Password recovery requires an email") + } + var err error + p.Email, err = a.validateEmail(p.Email) + if err != nil { + return err + } + if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { + return err + } + return nil +} + +// MagicLink sends a recovery email +func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + if !config.External.Email.Enabled { + return unprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") + } + + if !config.External.Email.MagicLinkEnabled { + return unprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Login with magic link is disabled") + } + + params := &MagicLinkParams{} + jsonDecoder := json.NewDecoder(r.Body) + err := jsonDecoder.Decode(params) + if err != nil { + return badRequestError(apierrors.ErrorCodeBadJSON, "Could not read verification params: %v", err).WithInternalError(err) + } + + if err := params.Validate(a); err != nil { + return err + } + + if params.Data == nil { + params.Data = make(map[string]interface{}) + } + + flowType := getFlowFromChallenge(params.CodeChallenge) + + var isNewUser bool + aud := a.requestAud(ctx, r) + user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) + if err != nil { + if models.IsNotFoundError(err) { + isNewUser = true + } else { + return internalServerError("Database error finding user").WithInternalError(err) + } + } + if user != nil { + isNewUser = !user.IsConfirmed() + } + if isNewUser { + // User either doesn't exist or hasn't completed the signup process. + // Sign them up with temporary password. + password := crypto.GeneratePassword(config.Password.RequiredCharacters, 33) + + signUpParams := &SignupParams{ + Email: params.Email, + Password: password, + Data: params.Data, + CodeChallengeMethod: params.CodeChallengeMethod, + CodeChallenge: params.CodeChallenge, + } + newBodyContent, err := json.Marshal(signUpParams) + if err != nil { + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) + } + r.Body = io.NopCloser(strings.NewReader(string(newBodyContent))) + r.ContentLength = int64(len(string(newBodyContent))) + + fakeResponse := &responseStub{} + if config.Mailer.Autoconfirm { + // signups are autoconfirmed, send magic link after signup + if err := a.Signup(fakeResponse, r); err != nil { + return err + } + newBodyContent := &SignupParams{ + Email: params.Email, + Data: params.Data, + CodeChallengeMethod: params.CodeChallengeMethod, + CodeChallenge: params.CodeChallenge, + } + metadata, err := json.Marshal(newBodyContent) + if err != nil { + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) + } + r.Body = io.NopCloser(bytes.NewReader(metadata)) + return a.MagicLink(w, r) + } + // otherwise confirmation email already contains 'magic link' + if err := a.Signup(fakeResponse, r); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, make(map[string]string)) + } + + if isPKCEFlow(flowType) { + if _, err = generateFlowState(a.db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { + return terr + } + return a.sendMagicLink(r, tx, user, flowType) + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, make(map[string]string)) +} + +// responseStub only implement http responsewriter for ignoring +// incoming data from methods where it passed +type responseStub struct { +} + +func (rw *responseStub) Header() http.Header { + return http.Header{} +} + +func (rw *responseStub) Write(data []byte) (int, error) { + return 1, nil +} + +func (rw *responseStub) WriteHeader(statusCode int) { +} diff --git a/internal/api/mail.go b/internal/api/mail.go new file mode 100644 index 000000000..31fc309ab --- /dev/null +++ b/internal/api/mail.go @@ -0,0 +1,686 @@ +package api + +import ( + "net/http" + "regexp" + "strings" + "time" + + "github.com/supabase/auth/internal/hooks" + mail "github.com/supabase/auth/internal/mailer" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/badoux/checkmail" + "github.com/fatih/structs" + "github.com/pkg/errors" + "github.com/sethvargo/go-password/password" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +var ( + EmailRateLimitExceeded error = errors.New("email rate limit exceeded") +) + +type GenerateLinkParams struct { + Type string `json:"type"` + Email string `json:"email"` + NewEmail string `json:"new_email"` + Password string `json:"password"` + Data map[string]interface{} `json:"data"` + RedirectTo string `json:"redirect_to"` +} + +type GenerateLinkResponse struct { + models.User + ActionLink string `json:"action_link"` + EmailOtp string `json:"email_otp"` + HashedToken string `json:"hashed_token"` + VerificationType string `json:"verification_type"` + RedirectTo string `json:"redirect_to"` +} + +func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + mailer := a.Mailer() + adminUser := getAdminUser(ctx) + params := &GenerateLinkParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + var err error + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + referrer := utilities.GetReferrer(r, config) + if utilities.IsRedirectURLValid(config, params.RedirectTo) { + referrer = params.RedirectTo + } + + aud := a.requestAud(ctx, r) + user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) + if err != nil { + if models.IsNotFoundError(err) { + switch params.Type { + case mail.MagicLinkVerification: + params.Type = mail.SignupVerification + params.Password, err = password.Generate(64, 10, 1, false, true) + if err != nil { + // password generation must always succeed + panic(err) + } + case mail.RecoveryVerification, mail.EmailChangeCurrentVerification, mail.EmailChangeNewVerification: + return notFoundError(apierrors.ErrorCodeUserNotFound, "User with this email not found") + } + } else { + return internalServerError("Database error finding user").WithInternalError(err) + } + } + + var url string + now := time.Now() + otp := crypto.GenerateOtp(config.Mailer.OtpLength) + + hashedToken := crypto.GenerateTokenHash(params.Email, otp) + + var signupUser *models.User + if params.Type == mail.SignupVerification && user == nil { + signupParams := &SignupParams{ + Email: params.Email, + Password: params.Password, + Data: params.Data, + Provider: "email", + Aud: aud, + } + + if err := a.validateSignupParams(ctx, signupParams); err != nil { + return err + } + + signupUser, err = signupParams.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + switch params.Type { + case mail.MagicLinkVerification, mail.RecoveryVerification: + if terr = models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { + return terr + } + user.RecoveryToken = hashedToken + user.RecoverySentAt = &now + terr = tx.UpdateOnly(user, "recovery_token", "recovery_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for recovery") + return terr + } + + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.RecoveryToken, models.RecoveryToken) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating recovery token in admin") + return terr + } + case mail.InviteVerification: + if user != nil { + if user.IsConfirmed() { + return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + } + } else { + signupParams := &SignupParams{ + Email: params.Email, + Data: params.Data, + Provider: "email", + Aud: aud, + } + + // because params above sets no password, this + // method is not computationally hard so it can + // be used within a database transaction + user, terr = signupParams.ToUserModel(false /* <- isSSOUser */) + if terr != nil { + return terr + } + + user, terr = a.signupNewUser(tx, user) + if terr != nil { + return terr + } + identity, terr := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + if terr != nil { + return terr + } + user.Identities = []models.Identity{*identity} + } + if terr = models.NewAuditLogEntry(r, tx, adminUser, models.UserInvitedAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + }); terr != nil { + return terr + } + user.ConfirmationToken = hashedToken + user.ConfirmationSentAt = &now + user.InvitedAt = &now + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for invite") + return terr + } + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating confirmation token for invite in admin") + return terr + } + case mail.SignupVerification: + if user != nil { + if user.IsConfirmed() { + return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + } + if err := user.UpdateUserMetaData(tx, params.Data); err != nil { + return internalServerError("Database error updating user").WithInternalError(err) + } + } else { + // you should never use SignupParams with + // password here to generate a new user, use + // signupUser which is a model generated from + // SignupParams above + user, terr = a.signupNewUser(tx, signupUser) + if terr != nil { + return terr + } + identity, terr := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + if terr != nil { + return terr + } + user.Identities = []models.Identity{*identity} + } + user.ConfirmationToken = hashedToken + user.ConfirmationSentAt = &now + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for confirmation") + return terr + } + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating confirmation token for signup in admin") + return terr + } + case mail.EmailChangeCurrentVerification, mail.EmailChangeNewVerification: + if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Enable secure email change to generate link for current email") + } + params.NewEmail, terr = a.validateEmail(params.NewEmail) + if terr != nil { + return terr + } + if duplicateUser, terr := models.IsDuplicatedEmail(tx, params.NewEmail, user.Aud, user); terr != nil { + return internalServerError("Database error checking email").WithInternalError(terr) + } else if duplicateUser != nil { + return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + } + now := time.Now() + user.EmailChangeSentAt = &now + user.EmailChange = params.NewEmail + user.EmailChangeConfirmStatus = zeroConfirmation + if params.Type == "email_change_current" { + user.EmailChangeTokenCurrent = hashedToken + } else if params.Type == "email_change_new" { + user.EmailChangeTokenNew = crypto.GenerateTokenHash(params.NewEmail, otp) + } + terr = tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for email change") + return terr + } + if user.EmailChangeTokenCurrent != "" { + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating email change token current in admin") + return terr + } + } + if user.EmailChangeTokenNew != "" { + terr = models.CreateOneTimeToken(tx, user.ID, user.EmailChange, user.EmailChangeTokenNew, models.EmailChangeTokenNew) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating email change token new in admin") + return terr + } + } + default: + return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) + } + + if terr != nil { + return terr + } + + externalURL := getExternalHost(ctx) + url, terr = mailer.GetEmailActionLink(user, params.Type, referrer, externalURL) + if terr != nil { + return terr + } + return nil + }) + + if err != nil { + return err + } + + resp := GenerateLinkResponse{ + User: *user, + ActionLink: url, + EmailOtp: otp, + HashedToken: hashedToken, + VerificationType: params.Type, + RedirectTo: referrer, + } + + return sendJSON(w, http.StatusOK, resp) +} + +func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { + var err error + + config := a.config + maxFrequency := config.SMTP.MaxFrequency + otpLength := config.Mailer.OtpLength + + if err = validateSentWithinFrequencyLimit(u.ConfirmationSentAt, maxFrequency); err != nil { + return err + } + oldToken := u.ConfirmationToken + otp := crypto.GenerateOtp(otpLength) + + token := crypto.GenerateTokenHash(u.GetEmail(), otp) + u.ConfirmationToken = addFlowPrefixToToken(token, flowType) + now := time.Now() + if err = a.sendEmail(r, tx, u, mail.SignupVerification, otp, "", u.ConfirmationToken); err != nil { + u.ConfirmationToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending confirmation email").WithInternalError(err) + } + u.ConfirmationSentAt = &now + if err := tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"); err != nil { + return internalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error updating user for confirmation")) + } + + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken); err != nil { + return internalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error creating confirmation token")) + } + + return nil +} + +func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User) error { + config := a.config + otpLength := config.Mailer.OtpLength + var err error + oldToken := u.ConfirmationToken + otp := crypto.GenerateOtp(otpLength) + + u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) + now := time.Now() + if err = a.sendEmail(r, tx, u, mail.InviteVerification, otp, "", u.ConfirmationToken); err != nil { + u.ConfirmationToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending invite email").WithInternalError(err) + } + u.InvitedAt = &now + u.ConfirmationSentAt = &now + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") + if err != nil { + return internalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error updating user for invite")) + } + + err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken) + if err != nil { + return internalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error creating confirmation token for invite")) + } + + return nil +} + +func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { + config := a.config + otpLength := config.Mailer.OtpLength + + if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { + return err + } + + oldToken := u.RecoveryToken + otp := crypto.GenerateOtp(otpLength) + + token := crypto.GenerateTokenHash(u.GetEmail(), otp) + u.RecoveryToken = addFlowPrefixToToken(token, flowType) + now := time.Now() + if err := a.sendEmail(r, tx, u, mail.RecoveryVerification, otp, "", u.RecoveryToken); err != nil { + u.RecoveryToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending recovery email").WithInternalError(err) + } + u.RecoverySentAt = &now + + if err := tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"); err != nil { + return internalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) + } + + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken); err != nil { + return internalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) + } + + return nil +} + +func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u *models.User) error { + config := a.config + maxFrequency := config.SMTP.MaxFrequency + otpLength := config.Mailer.OtpLength + + if err := validateSentWithinFrequencyLimit(u.ReauthenticationSentAt, maxFrequency); err != nil { + return err + } + + oldToken := u.ReauthenticationToken + otp := crypto.GenerateOtp(otpLength) + + u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) + now := time.Now() + + if err := a.sendEmail(r, tx, u, mail.ReauthenticationVerification, otp, "", u.ReauthenticationToken); err != nil { + u.ReauthenticationToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending reauthentication email").WithInternalError(err) + } + u.ReauthenticationSentAt = &now + if err := tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"); err != nil { + return internalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error updating user for reauthentication")) + } + + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ReauthenticationToken, models.ReauthenticationToken); err != nil { + return internalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error creating reauthentication token")) + } + + return nil +} + +func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { + var err error + config := a.config + otpLength := config.Mailer.OtpLength + + // since Magic Link is just a recovery with a different template and behaviour + // around new users we will reuse the recovery db timer to prevent potential abuse + if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { + return err + } + + oldToken := u.RecoveryToken + otp := crypto.GenerateOtp(otpLength) + + token := crypto.GenerateTokenHash(u.GetEmail(), otp) + u.RecoveryToken = addFlowPrefixToToken(token, flowType) + + now := time.Now() + if err = a.sendEmail(r, tx, u, mail.MagicLinkVerification, otp, "", u.RecoveryToken); err != nil { + u.RecoveryToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending magic link email").WithInternalError(err) + } + u.RecoverySentAt = &now + if err := tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"); err != nil { + return internalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) + } + + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken); err != nil { + return internalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) + } + + return nil +} + +// sendEmailChange sends out an email change token to the new email. +func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models.User, email string, flowType models.FlowType) error { + config := a.config + otpLength := config.Mailer.OtpLength + + if err := validateSentWithinFrequencyLimit(u.EmailChangeSentAt, config.SMTP.MaxFrequency); err != nil { + return err + } + + otpNew := crypto.GenerateOtp(otpLength) + + u.EmailChange = email + token := crypto.GenerateTokenHash(u.EmailChange, otpNew) + u.EmailChangeTokenNew = addFlowPrefixToToken(token, flowType) + + otpCurrent := "" + if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { + otpCurrent = crypto.GenerateOtp(otpLength) + + currentToken := crypto.GenerateTokenHash(u.GetEmail(), otpCurrent) + u.EmailChangeTokenCurrent = addFlowPrefixToToken(currentToken, flowType) + } + + u.EmailChangeConfirmStatus = zeroConfirmation + now := time.Now() + + if err := a.sendEmail(r, tx, u, mail.EmailChangeVerification, otpCurrent, otpNew, u.EmailChangeTokenNew); err != nil { + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending email change email").WithInternalError(err) + } + + u.EmailChangeSentAt = &now + if err := tx.UpdateOnly( + u, + "email_change_token_current", + "email_change_token_new", + "email_change", + "email_change_sent_at", + "email_change_confirm_status", + ); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error updating user for email change")) + } + + if u.EmailChangeTokenCurrent != "" { + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token current")) + } + } + + if u.EmailChangeTokenNew != "" { + if err := models.CreateOneTimeToken(tx, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token new")) + } + } + + return nil +} + +func (a *API) validateEmail(email string) (string, error) { + if email == "" { + return "", badRequestError(apierrors.ErrorCodeValidationFailed, "An email address is required") + } + if len(email) > 255 { + return "", badRequestError(apierrors.ErrorCodeValidationFailed, "An email address is too long") + } + if err := checkmail.ValidateFormat(email); err != nil { + return "", badRequestError(apierrors.ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) + } + + return strings.ToLower(email), nil +} + +func validateSentWithinFrequencyLimit(sentAt *time.Time, frequency time.Duration) error { + if sentAt != nil && sentAt.Add(frequency).After(time.Now()) { + return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, frequency)) + } + return nil +} + +var emailLabelPattern = regexp.MustCompile("[+][^@]+@") + +func (a *API) checkEmailAddressAuthorization(email string) bool { + if len(a.config.External.Email.AuthorizedAddresses) > 0 { + // allow labelled emails when authorization rules are in place + normalized := emailLabelPattern.ReplaceAllString(email, "@") + + for _, authorizedAddress := range a.config.External.Email.AuthorizedAddresses { + if strings.EqualFold(normalized, authorizedAddress) { + return true + } + } + + return false + } + + return true +} + +func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, emailActionType, otp, otpNew, tokenHashWithPrefix string) error { + ctx := r.Context() + config := a.config + referrerURL := utilities.GetReferrer(r, config) + externalURL := getExternalHost(ctx) + + if emailActionType != mail.EmailChangeVerification { + if u.GetEmail() != "" && !a.checkEmailAddressAuthorization(u.GetEmail()) { + return badRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.GetEmail()) + } + } else { + // first check that the user can update their address to the + // new one in u.EmailChange + if u.EmailChange != "" && !a.checkEmailAddressAuthorization(u.EmailChange) { + return badRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.EmailChange) + } + + // if secure email change is enabled, check that the user + // account (which could have been created before the authorized + // address authorization restriction was enabled) can even + // receive the confirmation message to the existing address + if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" && !a.checkEmailAddressAuthorization(u.GetEmail()) { + return badRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.GetEmail()) + } + } + + // if the number of events is set to zero, we immediately apply rate limits. + if config.RateLimitEmailSent.Events == 0 { + emailRateLimitCounter.Add( + ctx, + 1, + metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), + ) + return EmailRateLimitExceeded + } + + // TODO(km): Deprecate this behaviour - rate limits should still be applied to autoconfirm + if !config.Mailer.Autoconfirm { + // apply rate limiting before the email is sent out + if ok := a.limiterOpts.Email.Allow(); !ok { + emailRateLimitCounter.Add( + ctx, + 1, + metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), + ) + return EmailRateLimitExceeded + } + } + + if config.Hook.SendEmail.Enabled { + // When secure email change is disabled, we place the token for the new email on emailData.Token + if emailActionType == mail.EmailChangeVerification && !config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { + otp = otpNew + } + + emailData := mail.EmailData{ + Token: otp, + EmailActionType: emailActionType, + RedirectTo: referrerURL, + SiteURL: externalURL.String(), + TokenHash: tokenHashWithPrefix, + } + if emailActionType == mail.EmailChangeVerification && config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { + emailData.TokenNew = otpNew + emailData.TokenHashNew = u.EmailChangeTokenCurrent + } + input := hooks.SendEmailInput{ + User: u, + EmailData: emailData, + } + output := hooks.SendEmailOutput{} + return a.invokeHook(tx, r, &input, &output) + } + + mr := a.Mailer() + var err error + switch emailActionType { + case mail.SignupVerification: + err = mr.ConfirmationMail(r, u, otp, referrerURL, externalURL) + case mail.MagicLinkVerification: + err = mr.MagicLinkMail(r, u, otp, referrerURL, externalURL) + case mail.ReauthenticationVerification: + err = mr.ReauthenticateMail(r, u, otp) + case mail.RecoveryVerification: + err = mr.RecoveryMail(r, u, otp, referrerURL, externalURL) + case mail.InviteVerification: + err = mr.InviteMail(r, u, otp, referrerURL, externalURL) + case mail.EmailChangeVerification: + err = mr.EmailChangeMail(r, u, otpNew, otp, referrerURL, externalURL) + default: + err = errors.New("invalid email action type") + } + + switch { + case errors.Is(err, mail.ErrInvalidEmailAddress), + errors.Is(err, mail.ErrInvalidEmailFormat), + errors.Is(err, mail.ErrInvalidEmailDNS): + return badRequestError( + apierrors.ErrorCodeEmailAddressInvalid, + "Email address %q is invalid", + u.GetEmail()) + default: + return err + } +} diff --git a/internal/api/mail_test.go b/internal/api/mail_test.go new file mode 100644 index 000000000..c13c18e69 --- /dev/null +++ b/internal/api/mail_test.go @@ -0,0 +1,257 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gobwas/glob" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" +) + +type MailTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestMail(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &MailTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *MailTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + ts.Config.Mailer.SecureEmailChangeEnabled = true + + // Create User + u, err := models.NewUser("12345678", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating new user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new user") +} + +func (ts *MailTestSuite) TestValidateEmail() { + cases := []struct { + desc string + email string + expectedEmail string + expectedError error + }{ + { + desc: "valid email", + email: "test@example.com", + expectedEmail: "test@example.com", + expectedError: nil, + }, + { + desc: "email should be normalized", + email: "TEST@EXAMPLE.COM", + expectedEmail: "test@example.com", + expectedError: nil, + }, + { + desc: "empty email should return error", + email: "", + expectedEmail: "", + expectedError: badRequestError(apierrors.ErrorCodeValidationFailed, "An email address is required"), + }, + { + desc: "email length exceeds 255 characters", + // email has 256 characters + email: "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttest@example.com", + expectedEmail: "", + expectedError: badRequestError(apierrors.ErrorCodeValidationFailed, "An email address is too long"), + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + email, err := ts.API.validateEmail(c.email) + require.Equal(ts.T(), c.expectedError, err) + require.Equal(ts.T(), c.expectedEmail, email) + }) + } +} + +func (ts *MailTestSuite) TestGenerateLink() { + // create admin jwt + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err, "Error generating admin jwt") + + ts.setURIAllowListMap("http://localhost:8000/**") + // create test cases + cases := []struct { + Desc string + Body GenerateLinkParams + ExpectedCode int + ExpectedResponse map[string]interface{} + }{ + { + Desc: "Generate signup link for new user", + Body: GenerateLinkParams{ + Email: "new_user@example.com", + Password: "secret123", + Type: "signup", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate signup link for existing user", + Body: GenerateLinkParams{ + Email: "test@example.com", + Password: "secret123", + Type: "signup", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate signup link with custom redirect url", + Body: GenerateLinkParams{ + Email: "test@example.com", + Password: "secret123", + Type: "signup", + RedirectTo: "http://localhost:8000/welcome", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": "http://localhost:8000/welcome", + }, + }, + { + Desc: "Generate magic link", + Body: GenerateLinkParams{ + Email: "test@example.com", + Type: "magiclink", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate invite link", + Body: GenerateLinkParams{ + Email: "test@example.com", + Type: "invite", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate recovery link", + Body: GenerateLinkParams{ + Email: "test@example.com", + Type: "recovery", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate email change link", + Body: GenerateLinkParams{ + Email: "test@example.com", + NewEmail: "new@example.com", + Type: "email_change_current", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate email change link", + Body: GenerateLinkParams{ + Email: "test@example.com", + NewEmail: "new@example.com", + Type: "email_change_new", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + } + + customDomainUrl, err := url.ParseRequestURI("https://example.gotrue.com") + require.NoError(ts.T(), err) + + originalHosts := ts.API.config.Mailer.ExternalHosts + ts.API.config.Mailer.ExternalHosts = []string{ + "example.gotrue.com", + } + + for _, c := range cases { + ts.Run(c.Desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.Body)) + req := httptest.NewRequest(http.MethodPost, customDomainUrl.String()+"/admin/generate_link", &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), c.ExpectedCode, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Contains(ts.T(), data, "action_link") + require.Contains(ts.T(), data, "email_otp") + require.Contains(ts.T(), data, "hashed_token") + require.Contains(ts.T(), data, "redirect_to") + require.Equal(ts.T(), c.Body.Type, data["verification_type"]) + + // check if redirect_to is correct + require.Equal(ts.T(), c.ExpectedResponse["redirect_to"], data["redirect_to"]) + + // check if hashed_token matches hash function of email and the raw otp + require.Equal(ts.T(), crypto.GenerateTokenHash(c.Body.Email, data["email_otp"].(string)), data["hashed_token"]) + + // check if the host used in the email link matches the initial request host + u, err := url.ParseRequestURI(data["action_link"].(string)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), req.Host, u.Host) + }) + } + + ts.API.config.Mailer.ExternalHosts = originalHosts +} + +func (ts *MailTestSuite) setURIAllowListMap(uris ...string) { + for _, uri := range uris { + g := glob.MustCompile(uri, '.', '/') + ts.Config.URIAllowListMap[uri] = g + } +} diff --git a/internal/api/mfa.go b/internal/api/mfa.go new file mode 100644 index 000000000..2f79ccec2 --- /dev/null +++ b/internal/api/mfa.go @@ -0,0 +1,1030 @@ +package api + +import ( + "bytes" + "crypto/subtle" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/aaronarduino/goqrsvg" + svg "github.com/ajstarks/svgo" + "github.com/boombuler/barcode/qr" + wbnprotocol "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gofrs/uuid" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +const DefaultQRSize = 3 + +type EnrollFactorParams struct { + FriendlyName string `json:"friendly_name"` + FactorType string `json:"factor_type"` + Issuer string `json:"issuer"` + Phone string `json:"phone"` +} + +type TOTPObject struct { + QRCode string `json:"qr_code,omitempty"` + Secret string `json:"secret,omitempty"` + URI string `json:"uri,omitempty"` +} + +type EnrollFactorResponse struct { + ID uuid.UUID `json:"id"` + Type string `json:"type"` + FriendlyName string `json:"friendly_name"` + TOTP *TOTPObject `json:"totp,omitempty"` + Phone string `json:"phone,omitempty"` +} + +type ChallengeFactorParams struct { + Channel string `json:"channel"` + WebAuthn *WebAuthnParams `json:"web_authn,omitempty"` +} + +type VerifyFactorParams struct { + ChallengeID uuid.UUID `json:"challenge_id"` + Code string `json:"code"` + WebAuthn *WebAuthnParams `json:"web_authn,omitempty"` +} + +type ChallengeFactorResponse struct { + ID uuid.UUID `json:"id"` + Type string `json:"type"` + ExpiresAt int64 `json:"expires_at,omitempty"` + CredentialRequestOptions *wbnprotocol.CredentialAssertion `json:"credential_request_options,omitempty"` + CredentialCreationOptions *wbnprotocol.CredentialCreation `json:"credential_creation_options,omitempty"` +} + +type UnenrollFactorResponse struct { + ID uuid.UUID `json:"id"` +} + +type WebAuthnParams struct { + RPID string `json:"rp_id,omitempty"` + // Can encode multiple origins as comma separated values like: "origin1,origin2" + RPOrigins string `json:"rp_origins,omitempty"` + AssertionResponse json.RawMessage `json:"assertion_response,omitempty"` + CreationResponse json.RawMessage `json:"creation_response,omitempty"` +} + +func (w *WebAuthnParams) GetRPOrigins() []string { + if w.RPOrigins == "" { + return nil + } + return strings.Split(w.RPOrigins, ",") +} + +func (w *WebAuthnParams) ToConfig() (*webauthn.WebAuthn, error) { + if w.RPID == "" { + return nil, fmt.Errorf("webAuthn RP ID cannot be empty") + } + + origins := w.GetRPOrigins() + if len(origins) == 0 { + return nil, fmt.Errorf("webAuthn RP Origins cannot be empty") + } + + var validOrigins []string + var invalidOrigins []string + + for _, origin := range origins { + parsedURL, err := url.Parse(origin) + if err != nil || (parsedURL.Scheme != "https" && !(parsedURL.Scheme == "http" && parsedURL.Hostname() == "localhost")) || parsedURL.Host == "" { + invalidOrigins = append(invalidOrigins, origin) + } else { + validOrigins = append(validOrigins, origin) + } + } + + if len(invalidOrigins) > 0 { + return nil, fmt.Errorf("invalid RP origins: %s", strings.Join(invalidOrigins, ", ")) + } + + wconfig := &webauthn.Config{ + // DisplayName is optional in spec but required to be non-empty in libary, we use the RPID as a placeholder. + RPDisplayName: w.RPID, + RPID: w.RPID, + RPOrigins: validOrigins, + } + + return webauthn.New(wconfig) +} + +const ( + QRCodeGenerationErrorMessage = "Error generating QR Code" +) + +func validateFactors(db *storage.Connection, user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error { + if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { + return err + } + if err := db.Load(user, "Factors"); err != nil { + return err + } + factorCount := len(user.Factors) + numVerifiedFactors := 0 + + for _, factor := range user.Factors { + if factor.FriendlyName == newFactorName { + return unprocessableEntityError( + apierrors.ErrorCodeMFAFactorNameConflict, + fmt.Sprintf("A factor with the friendly name %q for this user already exists", newFactorName), + ) + } + if factor.IsVerified() { + numVerifiedFactors++ + } + } + + if factorCount >= int(config.MFA.MaxEnrolledFactors) { + return unprocessableEntityError(apierrors.ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + } + + if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { + return unprocessableEntityError(apierrors.ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + } + + if numVerifiedFactors > 0 && session != nil && !session.IsAAL2() { + return forbiddenError(apierrors.ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") + } + + return nil +} + +func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { + ctx := r.Context() + user := getUser(ctx) + session := getSession(ctx) + db := a.db.WithContext(ctx) + if params.Phone == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Phone number required to enroll Phone factor") + } + + phone, err := validatePhone(params.Phone) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + } + + var factorsToDelete []models.Factor + for _, factor := range user.Factors { + if factor.IsPhoneFactor() && factor.Phone.String() == phone { + if factor.IsVerified() { + return unprocessableEntityError( + apierrors.ErrorCodeMFAVerifiedFactorExists, + "A verified phone factor already exists, unenroll the existing factor to continue", + ) + } else if factor.IsUnverified() { + factorsToDelete = append(factorsToDelete, factor) + } + } + } + + if err := db.Destroy(&factorsToDelete); err != nil { + return internalServerError("Database error deleting unverified phone factors").WithInternalError(err) + } + + if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil { + return err + } + + factor := models.NewPhoneFactor(user, phone, params.FriendlyName) + err = db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(factor); terr != nil { + return terr + } + if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, &EnrollFactorResponse{ + ID: factor.ID, + Type: models.Phone, + FriendlyName: factor.FriendlyName, + Phone: params.Phone, + }) +} + +func (a *API) enrollWebAuthnFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { + ctx := r.Context() + user := getUser(ctx) + session := getSession(ctx) + db := a.db.WithContext(ctx) + + if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil { + return err + } + + factor := models.NewWebAuthnFactor(user, params.FriendlyName) + err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(factor); terr != nil { + return terr + } + if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, &EnrollFactorResponse{ + ID: factor.ID, + Type: models.WebAuthn, + FriendlyName: factor.FriendlyName, + }) +} + +func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { + ctx := r.Context() + user := getUser(ctx) + db := a.db.WithContext(ctx) + config := a.config + session := getSession(ctx) + issuer := "" + if params.Issuer == "" { + u, err := url.ParseRequestURI(config.SiteURL) + if err != nil { + return internalServerError("site url is improperly formatted") + } + issuer = u.Host + } else { + issuer = params.Issuer + } + + if err := validateFactors(db, user, params.FriendlyName, config, session); err != nil { + return err + } + + var factor *models.Factor + var buf bytes.Buffer + var key *otp.Key + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: issuer, + AccountName: user.GetEmail(), + }) + if err != nil { + return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) + } + + svgData := svg.New(&buf) + qrCode, _ := qr.Encode(key.String(), qr.H, qr.Auto) + qs := goqrsvg.NewQrSVG(qrCode, DefaultQRSize) + qs.StartQrSVG(svgData) + if err = qs.WriteQrSVG(svgData); err != nil { + return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) + } + svgData.End() + + factor = models.NewTOTPFactor(user, params.FriendlyName) + if err := factor.SetSecret(key.Secret(), config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + + err = db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(factor); terr != nil { + return terr + } + + if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, &EnrollFactorResponse{ + ID: factor.ID, + Type: models.TOTP, + FriendlyName: factor.FriendlyName, + TOTP: &TOTPObject{ + // See: https://css-tricks.com/probably-dont-base64-svg/ + QRCode: buf.String(), + Secret: key.Secret(), + URI: key.URL(), + }, + }) +} + +func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + session := getSession(ctx) + config := a.config + + if session == nil || user == nil { + return internalServerError("A valid session and a registered user are required to enroll a factor") + } + params := &EnrollFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + switch params.FactorType { + case models.Phone: + if !config.MFA.Phone.EnrollEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFAPhoneEnrollDisabled, "MFA enroll is disabled for Phone") + } + return a.enrollPhoneFactor(w, r, params) + case models.TOTP: + if !config.MFA.TOTP.EnrollEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFATOTPEnrollDisabled, "MFA enroll is disabled for TOTP") + } + return a.enrollTOTPFactor(w, r, params) + case models.WebAuthn: + if !config.MFA.WebAuthn.EnrollEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnEnrollDisabled, "MFA enroll is disabled for WebAuthn") + } + return a.enrollWebAuthnFactor(w, r, params) + default: + return badRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + } + +} + +func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + db := a.db.WithContext(ctx) + user := getUser(ctx) + factor := getFactor(ctx) + ipAddress := utilities.GetIPAddress(r) + params := &ChallengeFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + channel := params.Channel + if channel == "" { + channel = sms_provider.SMSProvider + } + if !sms_provider.IsValidMessageChannel(channel, config) { + return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) + } + + if factor.IsPhoneFactor() && factor.LastChallengedAt != nil { + if !factor.LastChallengedAt.Add(config.MFA.Phone.MaxFrequency).Before(time.Now()) { + return tooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(factor.LastChallengedAt, config.MFA.Phone.MaxFrequency)) + } + } + + otp := crypto.GenerateOtp(config.MFA.Phone.OtpLength) + + challenge, err := factor.CreatePhoneChallenge(ipAddress, otp, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey) + if err != nil { + return internalServerError("error creating SMS Challenge") + } + + message, err := generateSMSFromTemplate(config.MFA.Phone.SMSTemplate, otp) + if err != nil { + return internalServerError("error generating sms template").WithInternalError(err) + } + + if config.Hook.SendSMS.Enabled { + input := hooks.SendSMSInput{ + User: user, + SMS: hooks.SMS{ + OTP: otp, + SMSType: "mfa", + }, + } + output := hooks.SendSMSOutput{} + err := a.invokeHook(a.db, r, &input, &output) + if err != nil { + return internalServerError("error invoking hook") + } + } else { + smsProvider, err := sms_provider.GetSmsProvider(*config) + if err != nil { + return internalServerError("Failed to get SMS provider").WithInternalError(err) + } + // We omit messageID for now, can consider reinstating if there are requests. + if _, err = smsProvider.SendMessage(factor.Phone.String(), message, channel, otp); err != nil { + return internalServerError("error sending message").WithInternalError(err) + } + } + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := factor.WriteChallengeToDatabase(tx, challenge); terr != nil { + return terr + } + + if terr := models.NewAuditLogEntry(r, tx, user, models.CreateChallengeAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_status": factor.Status, + }); terr != nil { + return terr + } + return nil + }); err != nil { + return err + } + return sendJSON(w, http.StatusOK, &ChallengeFactorResponse{ + ID: challenge.ID, + Type: factor.FactorType, + ExpiresAt: challenge.GetExpiryTime(config.MFA.ChallengeExpiryDuration).Unix(), + }) +} + +func (a *API) challengeTOTPFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + db := a.db.WithContext(ctx) + + user := getUser(ctx) + factor := getFactor(ctx) + ipAddress := utilities.GetIPAddress(r) + + challenge := factor.CreateChallenge(ipAddress) + + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := factor.WriteChallengeToDatabase(tx, challenge); terr != nil { + return terr + } + if terr := models.NewAuditLogEntry(r, tx, user, models.CreateChallengeAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_status": factor.Status, + }); terr != nil { + return terr + } + return nil + }); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, &ChallengeFactorResponse{ + ID: challenge.ID, + Type: factor.FactorType, + ExpiresAt: challenge.GetExpiryTime(config.MFA.ChallengeExpiryDuration).Unix(), + }) +} + +func (a *API) challengeWebAuthnFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + user := getUser(ctx) + factor := getFactor(ctx) + ipAddress := utilities.GetIPAddress(r) + + params := &ChallengeFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if params.WebAuthn == nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "web_authn config required") + } + webAuthn, err := params.WebAuthn.ToConfig() + if err != nil { + return err + } + var response *ChallengeFactorResponse + var ws *models.WebAuthnSessionData + var challenge *models.Challenge + if factor.IsUnverified() { + options, session, err := webAuthn.BeginRegistration(user) + if err != nil { + return internalServerError("Failed to generate WebAuthn registration data").WithInternalError(err) + } + ws = &models.WebAuthnSessionData{ + SessionData: session, + } + challenge = ws.ToChallenge(factor.ID, ipAddress) + + response = &ChallengeFactorResponse{ + CredentialCreationOptions: options, + Type: factor.FactorType, + ID: challenge.ID, + } + + } else if factor.IsVerified() { + options, session, err := webAuthn.BeginLogin(user) + if err != nil { + return err + } + ws = &models.WebAuthnSessionData{ + SessionData: session, + } + challenge = ws.ToChallenge(factor.ID, ipAddress) + response = &ChallengeFactorResponse{ + CredentialRequestOptions: options, + Type: factor.FactorType, + ID: challenge.ID, + } + + } + + if err := factor.WriteChallengeToDatabase(db, challenge); err != nil { + return err + } + response.ExpiresAt = challenge.GetExpiryTime(config.MFA.ChallengeExpiryDuration).Unix() + + return sendJSON(w, http.StatusOK, response) + +} + +func (a *API) validateChallenge(r *http.Request, db *storage.Connection, factor *models.Factor, challengeID uuid.UUID) (*models.Challenge, error) { + config := a.config + currentIP := utilities.GetIPAddress(r) + + challenge, err := factor.FindChallengeByID(db, challengeID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, unprocessableEntityError(apierrors.ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") + } + return nil, internalServerError("Database error finding Challenge").WithInternalError(err) + } + + if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { + return nil, unprocessableEntityError(apierrors.ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch.") + } + + if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { + if err := db.Destroy(challenge); err != nil { + return nil, internalServerError("Database error deleting challenge").WithInternalError(err) + } + return nil, unprocessableEntityError(apierrors.ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) + } + + return challenge, nil +} + +func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + factor := getFactor(ctx) + + switch factor.FactorType { + case models.Phone: + if !config.MFA.Phone.VerifyEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFAPhoneVerifyDisabled, "MFA verification is disabled for Phone") + } + return a.challengePhoneFactor(w, r) + + case models.TOTP: + if !config.MFA.TOTP.VerifyEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFATOTPVerifyDisabled, "MFA verification is disabled for TOTP") + } + return a.challengeTOTPFactor(w, r) + case models.WebAuthn: + if !config.MFA.WebAuthn.VerifyEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnVerifyDisabled, "MFA verification is disabled for WebAuthn") + } + return a.challengeWebAuthnFactor(w, r) + default: + return badRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + } + +} + +func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { + var err error + ctx := r.Context() + user := getUser(ctx) + factor := getFactor(ctx) + config := a.config + db := a.db.WithContext(ctx) + + challenge, err := a.validateChallenge(r, db, factor, params.ChallengeID) + if err != nil { + return err + } + + secret, shouldReEncrypt, err := factor.GetSecret(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + if err != nil { + return internalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) + } + + valid, verr := totp.ValidateCustom(params.Code, secret, time.Now().UTC(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + + if config.Hook.MFAVerificationAttempt.Enabled { + input := hooks.MFAVerificationAttemptInput{ + UserID: user.ID, + FactorID: factor.ID, + Valid: valid, + } + + output := hooks.MFAVerificationAttemptOutput{} + err := a.invokeHook(nil, r, &input, &output) + if err != nil { + return err + } + + if output.Decision == hooks.HookRejection { + if err := models.Logout(db, user.ID); err != nil { + return err + } + + if output.Message == "" { + output.Message = hooks.DefaultMFAHookRejectionMessage + } + + return forbiddenError(apierrors.ErrorCodeMFAVerificationRejected, output.Message) + } + } + if !valid { + if shouldReEncrypt && config.Security.DBEncryption.Encrypt { + if err := factor.SetSecret(secret, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + + if err := db.UpdateOnly(factor, "secret"); err != nil { + return err + } + } + return unprocessableEntityError(apierrors.ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered").WithInternalError(verr) + } + + var token *AccessTokenResponse + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "challenge_id": challenge.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + if terr = challenge.Verify(tx); terr != nil { + return terr + } + if !factor.IsVerified() { + if terr = factor.UpdateStatus(tx, models.FactorStateVerified); terr != nil { + return terr + } + } + if shouldReEncrypt && config.Security.DBEncryption.Encrypt { + es, terr := crypto.NewEncryptedString(factor.ID.String(), []byte(secret), config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey) + if terr != nil { + return terr + } + + factor.Secret = es.String() + if terr := tx.UpdateOnly(factor, "secret"); terr != nil { + return terr + } + } + user, terr = models.FindUserByID(tx, user.ID) + if terr != nil { + return terr + } + + token, terr = a.updateMFASessionAndClaims(r, tx, user, models.TOTPSignIn, models.GrantParams{ + FactorID: &factor.ID, + }) + if terr != nil { + return terr + } + if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { + return internalServerError("Failed to update sessions. %s", terr) + } + if terr = models.DeleteUnverifiedFactors(tx, user, factor.FactorType); terr != nil { + return internalServerError("Error removing unverified factors. %s", terr) + } + return nil + }) + if err != nil { + return err + } + metering.RecordLogin(string(models.MFACodeLoginAction), user.ID) + + return sendJSON(w, http.StatusOK, token) + +} + +func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { + ctx := r.Context() + config := a.config + user := getUser(ctx) + factor := getFactor(ctx) + db := a.db.WithContext(ctx) + currentIP := utilities.GetIPAddress(r) + + challenge, err := a.validateChallenge(r, db, factor, params.ChallengeID) + if err != nil { + return err + } + + if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { + return unprocessableEntityError(apierrors.ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") + } + + if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { + if err := db.Destroy(challenge); err != nil { + return internalServerError("Database error deleting challenge").WithInternalError(err) + } + return unprocessableEntityError(apierrors.ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) + } + var valid bool + var otpCode string + var shouldReEncrypt bool + if config.Sms.IsTwilioVerifyProvider() { + smsProvider, err := sms_provider.GetSmsProvider(*config) + if err != nil { + return internalServerError("Failed to get SMS provider").WithInternalError(err) + } + if err := smsProvider.VerifyOTP(factor.Phone.String(), params.Code); err != nil { + return forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + } + valid = true + } else { + otpCode, shouldReEncrypt, err = challenge.GetOtpCode(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + if err != nil { + return internalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) + } + valid = subtle.ConstantTimeCompare([]byte(otpCode), []byte(params.Code)) == 1 + } + if config.Hook.MFAVerificationAttempt.Enabled { + input := hooks.MFAVerificationAttemptInput{ + UserID: user.ID, + FactorID: factor.ID, + FactorType: factor.FactorType, + Valid: valid, + } + + output := hooks.MFAVerificationAttemptOutput{} + err := a.invokeHook(nil, r, &input, &output) + if err != nil { + return err + } + + if output.Decision == hooks.HookRejection { + if err := models.Logout(db, user.ID); err != nil { + return err + } + + if output.Message == "" { + output.Message = hooks.DefaultMFAHookRejectionMessage + } + + return forbiddenError(apierrors.ErrorCodeMFAVerificationRejected, output.Message) + } + } + if !valid { + if shouldReEncrypt && config.Security.DBEncryption.Encrypt { + if err := challenge.SetOtpCode(otpCode, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + + if err := db.UpdateOnly(challenge, "otp_code"); err != nil { + return err + } + } + return unprocessableEntityError(apierrors.ErrorCodeMFAVerificationFailed, "Invalid MFA Phone code entered") + } + + var token *AccessTokenResponse + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "challenge_id": challenge.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + if terr = challenge.Verify(tx); terr != nil { + return terr + } + if !factor.IsVerified() { + if terr = factor.UpdateStatus(tx, models.FactorStateVerified); terr != nil { + return terr + } + } + user, terr = models.FindUserByID(tx, user.ID) + if terr != nil { + return terr + } + + token, terr = a.updateMFASessionAndClaims(r, tx, user, models.MFAPhone, models.GrantParams{ + FactorID: &factor.ID, + }) + if terr != nil { + return terr + } + if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { + return internalServerError("Failed to update sessions. %s", terr) + } + if terr = models.DeleteUnverifiedFactors(tx, user, factor.FactorType); terr != nil { + return internalServerError("Error removing unverified factors. %s", terr) + } + return nil + }) + if err != nil { + return err + } + metering.RecordLogin(string(models.MFACodeLoginAction), user.ID) + + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { + ctx := r.Context() + user := getUser(ctx) + factor := getFactor(ctx) + db := a.db.WithContext(ctx) + + var webAuthn *webauthn.WebAuthn + var credential *webauthn.Credential + var err error + + switch { + case params.WebAuthn == nil: + return badRequestError(apierrors.ErrorCodeValidationFailed, "WebAuthn config required") + case factor.IsVerified() && params.WebAuthn.AssertionResponse == nil: + return badRequestError(apierrors.ErrorCodeValidationFailed, "creation_response required to login") + case factor.IsUnverified() && params.WebAuthn.CreationResponse == nil: + return badRequestError(apierrors.ErrorCodeValidationFailed, "assertion_response required to login") + default: + webAuthn, err = params.WebAuthn.ToConfig() + if err != nil { + return err + } + } + + challenge, err := a.validateChallenge(r, db, factor, params.ChallengeID) + if err != nil { + return err + } + webAuthnSession := *challenge.WebAuthnSessionData.SessionData + // Once the challenge is validated, we consume the challenge + if err := db.Destroy(challenge); err != nil { + return internalServerError("Database error deleting challenge").WithInternalError(err) + } + + if factor.IsUnverified() { + parsedResponse, err := wbnprotocol.ParseCredentialCreationResponseBody(bytes.NewReader(params.WebAuthn.CreationResponse)) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid credential_creation_response") + } + credential, err = webAuthn.CreateCredential(user, webAuthnSession, parsedResponse) + if err != nil { + return err + } + + } else if factor.IsVerified() { + parsedResponse, err := wbnprotocol.ParseCredentialRequestResponseBody(bytes.NewReader(params.WebAuthn.AssertionResponse)) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid credential_request_response") + } + credential, err = webAuthn.ValidateLogin(user, webAuthnSession, parsedResponse) + if err != nil { + return internalServerError("Failed to validate WebAuthn MFA response").WithInternalError(err) + } + } + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "challenge_id": challenge.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + // Challenge verification not needed as the challenge is destroyed on use + if !factor.IsVerified() { + if terr = factor.UpdateStatus(tx, models.FactorStateVerified); terr != nil { + return terr + } + if terr = factor.SaveWebAuthnCredential(tx, credential); terr != nil { + return terr + } + } + user, terr = models.FindUserByID(tx, user.ID) + if terr != nil { + return terr + } + token, terr = a.updateMFASessionAndClaims(r, tx, user, models.MFAWebAuthn, models.GrantParams{ + FactorID: &factor.ID, + }) + if terr != nil { + return terr + } + if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { + return internalServerError("Failed to update session").WithInternalError(terr) + } + if terr = models.DeleteUnverifiedFactors(tx, user, models.WebAuthn); terr != nil { + return internalServerError("Failed to remove unverified MFA WebAuthn factors").WithInternalError(terr) + } + return nil + }) + if err != nil { + return err + } + metering.RecordLogin(string(models.MFACodeLoginAction), user.ID) + + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + factor := getFactor(ctx) + config := a.config + + params := &VerifyFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if params.Code == "" && factor.FactorType != models.WebAuthn { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Code needs to be non-empty") + } + + switch factor.FactorType { + case models.Phone: + if !config.MFA.Phone.VerifyEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFAPhoneVerifyDisabled, "MFA verification is disabled for Phone") + } + + return a.verifyPhoneFactor(w, r, params) + case models.TOTP: + if !config.MFA.TOTP.VerifyEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFATOTPVerifyDisabled, "MFA verification is disabled for TOTP") + } + return a.verifyTOTPFactor(w, r, params) + case models.WebAuthn: + if !config.MFA.WebAuthn.VerifyEnabled { + return unprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnEnrollDisabled, "MFA verification is disabled for WebAuthn") + } + return a.verifyWebAuthnFactor(w, r, params) + default: + return badRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + } + +} + +func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { + var err error + ctx := r.Context() + user := getUser(ctx) + factor := getFactor(ctx) + session := getSession(ctx) + db := a.db.WithContext(ctx) + + if factor == nil || session == nil || user == nil { + return internalServerError("A valid session and factor are required to unenroll a factor") + } + + if factor.IsVerified() && !session.IsAAL2() { + return unprocessableEntityError(apierrors.ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr := tx.Destroy(factor); terr != nil { + return terr + } + if terr = models.NewAuditLogEntry(r, tx, user, models.UnenrollFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_status": factor.Status, + "session_id": session.ID, + }); terr != nil { + return terr + } + if terr = factor.DowngradeSessionsToAAL1(tx); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, &UnenrollFactorResponse{ + ID: factor.ID, + }) +} diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go new file mode 100644 index 000000000..6d9a5a3c3 --- /dev/null +++ b/internal/api/mfa_test.go @@ -0,0 +1,1012 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gofrs/uuid" + + "github.com/pquerna/otp" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/utilities" + + "github.com/pquerna/otp/totp" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type MFATestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + TestDomain string + TestEmail string + TestOTPKey *otp.Key + TestPassword string + TestUser *models.User + TestSession *models.Session + TestSecondarySession *models.Session +} + +func TestMFA(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + ts := &MFATestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + suite.Run(t, ts) +} + +func (ts *MFATestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + ts.TestEmail = "test@example.com" + ts.TestPassword = "password" + // Create user + u, err := models.NewUser("123456789", ts.TestEmail, ts.TestPassword, ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + // Create Factor + f := models.NewTOTPFactor(u, "test_factor") + require.NoError(ts.T(), f.SetSecret("secretkey", ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey)) + require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor") + // Create corresponding session + s, err := models.NewSession(u.ID, &f.ID) + require.NoError(ts.T(), err, "Error creating test session") + require.NoError(ts.T(), ts.API.db.Create(s), "Error saving test session") + + u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud) + ts.Require().NoError(err) + + ts.TestUser = u + ts.TestSession = s + + secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID) + require.NoError(ts.T(), err, "Error creating test session") + require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session") + + ts.TestSecondarySession = secondarySession + + // Generate TOTP related settings + testDomain := strings.Split(ts.TestEmail, "@")[1] + ts.TestDomain = testDomain + + // By default MFA Phone is disabled + ts.Config.MFA.Phone.EnrollEnabled = true + ts.Config.MFA.Phone.VerifyEnabled = true + + ts.Config.MFA.WebAuthn.EnrollEnabled = true + ts.Config.MFA.WebAuthn.VerifyEnabled = true + + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: ts.TestDomain, + AccountName: ts.TestEmail, + }) + require.NoError(ts.T(), err) + ts.TestOTPKey = key + +} + +func (ts *MFATestSuite) generateAAL1Token(user *models.User, sessionId *uuid.UUID) string { + // Not an actual path. Dummy request to simulate a signup request that we can use in generateAccessToken + req := httptest.NewRequest(http.MethodPost, "/factors", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, sessionId, models.TOTPSignIn) + require.NoError(ts.T(), err, "Error generating access token") + return token +} + +func (ts *MFATestSuite) TestEnrollFactor() { + testFriendlyName := "bob" + alternativeFriendlyName := "john" + + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + var cases = []struct { + desc string + friendlyName string + factorType string + issuer string + phone string + expectedCode int + }{ + { + desc: "TOTP: No issuer", + friendlyName: alternativeFriendlyName, + factorType: models.TOTP, + issuer: "", + phone: "", + expectedCode: http.StatusOK, + }, + { + desc: "Invalid factor type", + friendlyName: testFriendlyName, + factorType: "invalid_factor", + issuer: ts.TestDomain, + phone: "", + expectedCode: http.StatusBadRequest, + }, + { + desc: "TOTP: Factor has friendly name", + friendlyName: testFriendlyName, + factorType: models.TOTP, + issuer: ts.TestDomain, + phone: "", + expectedCode: http.StatusOK, + }, + { + desc: "TOTP: Enrolling without friendly name", + friendlyName: "", + factorType: models.TOTP, + issuer: ts.TestDomain, + phone: "", + expectedCode: http.StatusOK, + }, + { + desc: "Phone: Enroll with friendly name", + friendlyName: "phone_factor", + factorType: models.Phone, + phone: "+12345677889", + expectedCode: http.StatusOK, + }, + { + desc: "Phone: Enroll with invalid phone number", + friendlyName: "phone_factor", + factorType: models.Phone, + phone: "+1", + expectedCode: http.StatusBadRequest, + }, + { + desc: "Phone: Enroll without phone number should return error", + friendlyName: "phone_factor_fail", + factorType: models.Phone, + phone: "", + expectedCode: http.StatusBadRequest, + }, + { + desc: "WebAuthn: Enroll with friendly name", + friendlyName: "webauthn_factor", + factorType: models.WebAuthn, + expectedCode: http.StatusOK, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + w := performEnrollFlow(ts, token, c.friendlyName, c.factorType, c.issuer, c.phone, c.expectedCode) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + + if c.expectedCode == http.StatusOK { + addedFactor, err := models.FindFactorByFactorID(ts.API.db, enrollResp.ID) + require.NoError(ts.T(), err) + require.False(ts.T(), addedFactor.IsVerified()) + + if c.friendlyName != "" { + require.Equal(ts.T(), c.friendlyName, addedFactor.FriendlyName) + } + + if c.factorType == models.TOTP { + qrCode := enrollResp.TOTP.QRCode + hasSVGStartAndEnd := strings.Contains(qrCode, "") + require.True(ts.T(), hasSVGStartAndEnd) + require.Equal(ts.T(), c.friendlyName, enrollResp.FriendlyName) + } + } + + }) + } +} + +func (ts *MFATestSuite) TestDuplicateEnrollPhoneFactor() { + testPhoneNumber := "+12345677889" + altPhoneNumber := "+987412444444" + friendlyName := "phone_factor" + altFriendlyName := "alt_phone_factor" + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + var cases = []struct { + desc string + earlierFactorName string + laterFactorName string + phone string + secondPhone string + expectedCode int + expectedNumberOfFactors int + }{ + { + desc: "Phone: Only the latest factor should persist when enrolling two unverified phone factors with the same number", + earlierFactorName: friendlyName, + laterFactorName: altFriendlyName, + phone: testPhoneNumber, + secondPhone: testPhoneNumber, + expectedNumberOfFactors: 1, + }, + + { + desc: "Phone: Both factors should persist when enrolling two different unverified numbers", + earlierFactorName: friendlyName, + laterFactorName: altFriendlyName, + phone: testPhoneNumber, + secondPhone: altPhoneNumber, + expectedNumberOfFactors: 2, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Delete all test factors to start from clean slate + require.NoError(ts.T(), ts.API.db.Destroy(ts.TestUser.Factors)) + _ = performEnrollFlow(ts, token, c.earlierFactorName, models.Phone, ts.TestDomain, c.phone, http.StatusOK) + + w := performEnrollFlow(ts, token, c.laterFactorName, models.Phone, ts.TestDomain, c.secondPhone, http.StatusOK) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + + laterFactor, err := models.FindFactorByFactorID(ts.API.db, enrollResp.ID) + require.NoError(ts.T(), err) + require.False(ts.T(), laterFactor.IsVerified()) + + require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) + require.Equal(ts.T(), len(ts.TestUser.Factors), c.expectedNumberOfFactors) + + }) + } +} + +func (ts *MFATestSuite) TestDuplicateEnrollPhoneFactorWithVerified() { + testPhoneNumber := "+12345677889" + friendlyName := "phone_factor" + altFriendlyName := "alt_phone_factor" + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + ts.Run("Phone: Enrolling a factor with the same number as an existing verified phone factor should result in an error", func() { + require.NoError(ts.T(), ts.API.db.Destroy(ts.TestUser.Factors)) + + // Setup verified factor + w := performEnrollFlow(ts, token, friendlyName, models.Phone, ts.TestDomain, testPhoneNumber, http.StatusOK) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + firstFactor, err := models.FindFactorByFactorID(ts.API.db, enrollResp.ID) + require.NoError(ts.T(), err) + require.NoError(ts.T(), firstFactor.UpdateStatus(ts.API.db, models.FactorStateVerified)) + + expectedStatusCode := http.StatusUnprocessableEntity + _ = performEnrollFlow(ts, token, altFriendlyName, models.Phone, ts.TestDomain, testPhoneNumber, expectedStatusCode) + + require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) + require.Equal(ts.T(), len(ts.TestUser.Factors), 1) + }) +} + +func (ts *MFATestSuite) TestDuplicateTOTPEnrollsReturnExpectedMessage() { + friendlyName := "mary" + issuer := "https://issuer.com" + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + _ = performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, "", http.StatusOK) + response := performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, "", http.StatusUnprocessableEntity) + + var errorResponse HTTPError + err := json.NewDecoder(response.Body).Decode(&errorResponse) + require.NoError(ts.T(), err) + + require.Contains(ts.T(), errorResponse.ErrorCode, apierrors.ErrorCodeMFAFactorNameConflict) +} + +func (ts *MFATestSuite) AAL2RequiredToUpdatePasswordAfterEnrollment() { + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + var w *httptest.ResponseRecorder + var buffer bytes.Buffer + token := accessTokenResp.Token + // Update Password to new password + newPassword := "newpass" + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": newPassword, + })) + + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Logout + reqURL := "http://localhost/logout" + req = httptest.NewRequest(http.MethodPost, reqURL, nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + // Get AAL1 token + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": ts.TestEmail, + "password": newPassword, + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + session1 := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&session1)) + + // Update Password again, this should fail + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": ts.TestPassword, + })) + + req = httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session1.Token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusUnauthorized, w.Code) + +} + +func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { + // All factors are deleted when a subsequent enroll is made + ts.API.config.MFA.FactorExpiryDuration = 0 * time.Second + // Verified factor should not be deleted (Factor 1) + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + numFactors := 5 + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + var w *httptest.ResponseRecorder + token := accessTokenResp.Token + for i := 0; i < numFactors; i++ { + w = performEnrollFlow(ts, token, "first-name", models.TOTP, "https://issuer.com", "", http.StatusOK) + } + + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + + // Make a challenge so last, unverified factor isn't deleted on next enroll (Factor 2) + _ = performChallengeFlow(ts, enrollResp.ID, token) + + // Enroll another Factor (Factor 3) + _ = performEnrollFlow(ts, token, "second-name", models.TOTP, "https://issuer.com", "", http.StatusOK) + require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) + require.Equal(ts.T(), 3, len(ts.TestUser.Factors)) +} + +func (ts *MFATestSuite) TestChallengeTOTPFactor() { + // Test Factor is a TOTP Factor + f := ts.TestUser.Factors[0] + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + w := performChallengeFlow(ts, f.ID, token) + challengeResp := ChallengeFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&challengeResp)) + + require.Equal(ts.T(), http.StatusOK, w.Code) + require.Equal(ts.T(), challengeResp.Type, models.TOTP) + +} + +func (ts *MFATestSuite) TestChallengeSMSFactor() { + // Challenge should still work with phone provider disabled + ts.Config.External.Phone.Enabled = false + ts.Config.Hook.SendSMS.Enabled = true + ts.Config.Hook.SendSMS.URI = "pg-functions://postgres/auth/send_sms_mfa_mock" + + ts.Config.MFA.Phone.MaxFrequency = 0 * time.Second + + require.NoError(ts.T(), ts.Config.Hook.SendSMS.PopulateExtensibilityPoint()) + require.NoError(ts.T(), ts.API.db.RawQuery(` + create or replace function send_sms_mfa_mock(input jsonb) + returns json as $$ + begin + return input; + end; $$ language plpgsql;`).Exec()) + + phone := "+1234567" + friendlyName := "testchallengesmsfactor" + + f := models.NewPhoneFactor(ts.TestUser, phone, friendlyName) + require.NoError(ts.T(), ts.API.db.Create(f), "Error creating new SMS factor") + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + var cases = []struct { + desc string + channel string + expectedCode int + }{ + { + desc: "SMS Channel", + channel: sms_provider.SMSProvider, + expectedCode: http.StatusOK, + }, + { + desc: "WhatsApp Channel", + channel: sms_provider.WhatsappProvider, + expectedCode: http.StatusOK, + }, + } + + for _, tc := range cases { + ts.Run(tc.desc, func() { + w := performSMSChallengeFlow(ts, f.ID, token, tc.channel) + challengeResp := ChallengeFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&challengeResp)) + require.Equal(ts.T(), challengeResp.Type, models.Phone) + require.Equal(ts.T(), tc.expectedCode, w.Code, tc.desc) + }) + } +} + +func (ts *MFATestSuite) TestMFAVerifyFactor() { + cases := []struct { + desc string + validChallenge bool + validCode bool + factorType string + expectedHTTPCode int + }{ + { + desc: "Invalid: Valid code and expired challenge", + validChallenge: false, + validCode: true, + factorType: models.TOTP, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Invalid: Invalid code and valid challenge", + validChallenge: true, + validCode: false, + factorType: models.TOTP, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Valid /verify request", + validChallenge: true, + validCode: true, + factorType: models.TOTP, + expectedHTTPCode: http.StatusOK, + }, + { + desc: "Invalid: Valid code and expired challenge (SMS)", + validChallenge: false, + validCode: true, + factorType: models.Phone, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Invalid: Invalid code and valid challenge (SMS)", + validChallenge: true, + validCode: false, + factorType: models.Phone, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Valid /verify request (SMS)", + validChallenge: true, + validCode: true, + factorType: models.Phone, + expectedHTTPCode: http.StatusOK, + }, + } + for _, v := range cases { + ts.Run(v.desc, func() { + // Authenticate users and set secret + var buffer bytes.Buffer + r, err := models.GrantAuthenticatedUser(ts.API.db, ts.TestUser, models.GrantParams{}) + require.NoError(ts.T(), err) + token := ts.generateAAL1Token(ts.TestUser, r.SessionId) + var f *models.Factor + var sharedSecret string + + if v.factorType == models.TOTP { + friendlyName := uuid.Must(uuid.NewV4()).String() + f = models.NewTOTPFactor(ts.TestUser, friendlyName) + sharedSecret = ts.TestOTPKey.Secret() + f.Secret = sharedSecret + require.NoError(ts.T(), ts.API.db.Create(f), "Error updating new test factor") + } else if v.factorType == models.Phone { + friendlyName := uuid.Must(uuid.NewV4()).String() + numDigits := 10 + otp := crypto.GenerateOtp(numDigits) + require.NoError(ts.T(), err) + phone := fmt.Sprintf("+%s", otp) + f = models.NewPhoneFactor(ts.TestUser, phone, friendlyName) + require.NoError(ts.T(), ts.API.db.Create(f), "Error creating new SMS factor") + } + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", f.ID), &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + var c *models.Challenge + var code string + if v.factorType == models.TOTP { + c = f.CreateChallenge(utilities.GetIPAddress(req)) + // Verify TOTP code + code, err = totp.GenerateCode(sharedSecret, time.Now().UTC()) + require.NoError(ts.T(), err) + } else if v.factorType == models.Phone { + code = "123456" + c, err = f.CreatePhoneChallenge(utilities.GetIPAddress(req), code, ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey) + require.NoError(ts.T(), err) + } + + if !v.validCode && v.factorType == models.TOTP { + code, err = totp.GenerateCode(sharedSecret, time.Now().UTC().Add(-1*time.Minute*time.Duration(1))) + require.NoError(ts.T(), err) + + } else if !v.validCode && v.factorType == models.Phone { + invalidSuffix := "1" + code += invalidSuffix + } + + require.NoError(ts.T(), ts.API.db.Create(c), "Error saving new test challenge") + if !v.validChallenge { + // Set challenge creation so that it has expired in present time. + newCreatedAt := time.Now().UTC().Add(-1 * time.Second * time.Duration(ts.Config.MFA.ChallengeExpiryDuration+1)) + // created_at is managed by buffalo(ORM) needs to be raw query to be updated + err := ts.API.db.RawQuery("UPDATE auth.mfa_challenges SET created_at = ? WHERE factor_id = ?", newCreatedAt, f.ID).Exec() + require.NoError(ts.T(), err, "Error updating new test challenge") + } + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "challenge_id": c.ID, + "code": code, + })) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), v.expectedHTTPCode, w.Code) + + if v.expectedHTTPCode == http.StatusOK { + // Ensure alternate session has been deleted + _, err = models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false) + require.EqualError(ts.T(), err, models.SessionNotFoundError{}.Error()) + } + if !v.validChallenge { + // Ensure invalid challenges are deleted + _, err := f.FindChallengeByID(ts.API.db, c.ID) + require.EqualError(ts.T(), err, models.ChallengeNotFoundError{}.Error()) + } + }) + } +} + +func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { + cases := []struct { + desc string + isAAL2 bool + expectedHTTPCode int + }{ + { + desc: "Verified Factor: AAL1", + isAAL2: false, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Verified Factor: AAL2, Success", + isAAL2: true, + expectedHTTPCode: http.StatusOK, + }, + } + for _, v := range cases { + ts.Run(v.desc, func() { + var buffer bytes.Buffer + + // Create Session to test behaviour which downgrades other sessions + f := ts.TestUser.Factors[0] + require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified)) + if v.isAAL2 { + ts.TestSession.UpdateAALAndAssociatedFactor(ts.API.db, models.AAL2, &f.ID) + } + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer) + require.Equal(ts.T(), v.expectedHTTPCode, w.Code) + + if v.expectedHTTPCode == http.StatusOK { + _, err := models.FindFactorByFactorID(ts.API.db, f.ID) + require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error()) + session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false) + require.Equal(ts.T(), models.AAL1.String(), session.GetAAL()) + require.Nil(ts.T(), session.FactorID) + + } + }) + } + +} + +func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() { + var buffer bytes.Buffer + f := ts.TestUser.Factors[0] + + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "factor_id": f.ID, + })) + + w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer) + require.Equal(ts.T(), http.StatusOK, w.Code) + + _, err := models.FindFactorByFactorID(ts.API.db, f.ID) + require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error()) + session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false) + require.Equal(ts.T(), models.AAL1.String(), session.GetAAL()) + require.Nil(ts.T(), session.FactorID) + +} + +// Integration Tests +func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { + ts.Config.Security.RefreshTokenRotationEnabled = true + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": accessTokenResp.RefreshToken, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + ctx, err := ts.API.parseJWTClaims(data.Token, req) + require.NoError(ts.T(), err) + ctx, err = ts.API.maybeLoadUserOrSession(ctx) + require.NoError(ts.T(), err) + require.True(ts.T(), getSession(ctx).IsAAL2()) +} + +// Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session +func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() { + ts.Config.Security.RefreshTokenRotationEnabled = true + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": ts.TestEmail, + "password": ts.TestPassword, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + ctx, err := ts.API.parseJWTClaims(data.Token, req) + require.NoError(ts.T(), err) + + ctx, err = ts.API.maybeLoadUserOrSession(ctx) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), models.AAL1.String(), getSession(ctx).GetAAL()) + session, err := models.FindSessionByUserID(ts.API.db, accessTokenResp.User.ID) + require.NoError(ts.T(), err) + require.True(ts.T(), session.IsAAL2()) +} + +func (ts *MFATestSuite) TestChallengeWebAuthnFactor() { + factor := models.NewWebAuthnFactor(ts.TestUser, "WebAuthnfactor") + validWebAuthnConfiguration := &WebAuthnParams{ + RPID: "localhost", + RPOrigins: "http://localhost:3000", + } + require.NoError(ts.T(), ts.API.db.Create(factor), "Error saving new test factor") + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + w := performChallengeWebAuthnFlow(ts, factor.ID, token, validWebAuthnConfiguration) + require.Equal(ts.T(), http.StatusOK, w.Code) +} + +func performChallengeWebAuthnFlow(ts *MFATestSuite, factorID uuid.UUID, token string, webauthn *WebAuthnParams) *httptest.ResponseRecorder { + var buffer bytes.Buffer + err := json.NewEncoder(&buffer).Encode(ChallengeFactorParams{WebAuthn: webauthn}) + require.NoError(ts.T(), err) + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer) + require.Equal(ts.T(), http.StatusOK, w.Code) + return w +} + +func (ts *MFATestSuite) TestChallengeFactorNotOwnedByUser() { + var buffer bytes.Buffer + email := "nomfaenabled@test.com" + password := "testpassword" + signUpResp := signUp(ts, email, password) + + friendlyName := "testfactor" + phoneNumber := "+1234567" + + otherUsersPhoneFactor := models.NewPhoneFactor(ts.TestUser, phoneNumber, friendlyName) + require.NoError(ts.T(), ts.API.db.Create(otherUsersPhoneFactor), "Error creating factor") + + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", otherUsersPhoneFactor.ID), signUpResp.Token, buffer) + + expectedError := notFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found") + + var data HTTPError + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Equal(ts.T(), expectedError.ErrorCode, data.ErrorCode) + require.Equal(ts.T(), http.StatusNotFound, w.Code) + +} + +func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenResponse) { + ts.API.config.Mailer.Autoconfirm = true + var buffer bytes.Buffer + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": email, + "password": password, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + data := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + return data +} + +func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requireStatusOK bool) *httptest.ResponseRecorder { + signUpResp := signUp(ts, email, password) + resp := performEnrollAndVerify(ts, signUpResp.Token, requireStatusOK) + + return resp + +} + +func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer string, phone string, expectedCode int) *httptest.ResponseRecorder { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(EnrollFactorParams{FriendlyName: friendlyName, FactorType: factorType, Issuer: issuer, Phone: phone})) + w := ServeAuthenticatedRequest(ts, http.MethodPost, "http://localhost/factors/", token, buffer) + require.Equal(ts.T(), expectedCode, w.Code) + return w +} + +func ServeAuthenticatedRequest(ts *MFATestSuite, method, path, token string, buffer bytes.Buffer) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + req := httptest.NewRequest(method, path, &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + req.Header.Set("Content-Type", "application/json") + + ts.API.handler.ServeHTTP(w, req) + return w +} + +func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, requireStatusOK bool) *httptest.ResponseRecorder { + var buffer bytes.Buffer + + factor, err := models.FindFactorByFactorID(ts.API.db, factorID) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), factor) + + totpSecret := factor.Secret + + if es := crypto.ParseEncryptedString(factor.Secret); es != nil { + secret, err := es.Decrypt(factor.ID.String(), ts.API.config.Security.DBEncryption.DecryptionKeys) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), secret) + + totpSecret = string(secret) + } + + code, err := totp.GenerateCode(totpSecret, time.Now().UTC()) + require.NoError(ts.T(), err) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "challenge_id": challengeID, + "code": code, + })) + + y := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), token, buffer) + + if requireStatusOK { + require.Equal(ts.T(), http.StatusOK, y.Code) + } + return y +} + +func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *httptest.ResponseRecorder { + var buffer bytes.Buffer + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer) + require.Equal(ts.T(), http.StatusOK, w.Code) + return w + +} + +func performSMSChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token, channel string) *httptest.ResponseRecorder { + params := ChallengeFactorParams{ + Channel: channel, + } + var buffer bytes.Buffer + if err := json.NewEncoder(&buffer).Encode(params); err != nil { + panic(err) // handle the error appropriately in real code + } + + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer) + require.Equal(ts.T(), http.StatusOK, w.Code) + return w + +} + +func performEnrollAndVerify(ts *MFATestSuite, token string, requireStatusOK bool) *httptest.ResponseRecorder { + w := performEnrollFlow(ts, token, "", models.TOTP, ts.TestDomain, "", http.StatusOK) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + factorID := enrollResp.ID + + // Challenge + w = performChallengeFlow(ts, factorID, token) + + challengeResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&challengeResp)) + challengeID := challengeResp.ID + + // Verify + y := performVerifyFlow(ts, challengeID, factorID, token, requireStatusOK) + + return y +} + +func (ts *MFATestSuite) TestVerificationHooks() { + type verificationHookTestCase struct { + desc string + enabled bool + uri string + hookFunctionSQL string + emailSuffix string + expectToken bool + expectedCode int + cleanupHookFunction string + } + cases := []verificationHookTestCase{ + { + desc: "Default Success", + enabled: true, + uri: "pg-functions://postgres/auth/verification_hook", + hookFunctionSQL: ` + create or replace function verification_hook(input jsonb) + returns json as $$ + begin + return json_build_object('decision', 'continue'); + end; $$ language plpgsql;`, + emailSuffix: "success", + expectToken: true, + expectedCode: http.StatusOK, + cleanupHookFunction: "verification_hook(input jsonb)", + }, + { + desc: "Error", + enabled: true, + uri: "pg-functions://postgres/auth/test_verification_hook_error", + hookFunctionSQL: ` + create or replace function test_verification_hook_error(input jsonb) + returns json as $$ + begin + RAISE EXCEPTION 'Intentional Error for Testing'; + end; $$ language plpgsql;`, + emailSuffix: "error", + expectToken: false, + expectedCode: http.StatusInternalServerError, + cleanupHookFunction: "test_verification_hook_error(input jsonb)", + }, + { + desc: "Reject - Enabled", + enabled: true, + uri: "pg-functions://postgres/auth/verification_hook_reject", + hookFunctionSQL: ` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', 'authentication attempt rejected' + ); + end; $$ language plpgsql;`, + emailSuffix: "reject_enabled", + expectToken: false, + expectedCode: http.StatusForbidden, + cleanupHookFunction: "verification_hook_reject(input jsonb)", + }, + { + desc: "Reject - Disabled", + enabled: false, + uri: "pg-functions://postgres/auth/verification_hook_reject", + hookFunctionSQL: ` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', 'authentication attempt rejected' + ); + end; $$ language plpgsql;`, + emailSuffix: "reject_disabled", + expectToken: true, + expectedCode: http.StatusOK, + cleanupHookFunction: "verification_hook_reject(input jsonb)", + }, + { + desc: "Timeout", + enabled: true, + uri: "pg-functions://postgres/auth/test_verification_hook_timeout", + hookFunctionSQL: ` + create or replace function test_verification_hook_timeout(input jsonb) + returns json as $$ + begin + PERFORM pg_sleep(3); + return json_build_object( + 'decision', 'continue' + ); + end; $$ language plpgsql;`, + emailSuffix: "timeout", + expectToken: false, + expectedCode: http.StatusInternalServerError, + cleanupHookFunction: "test_verification_hook_timeout(input jsonb)", + }, + } + + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.MFAVerificationAttempt.Enabled = c.enabled + ts.Config.Hook.MFAVerificationAttempt.URI = c.uri + require.NoError(ts.T(), ts.Config.Hook.MFAVerificationAttempt.PopulateExtensibilityPoint()) + + err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() + require.NoError(t, err) + + email := fmt.Sprintf("testemail_%s@gmail.com", c.emailSuffix) + password := "testpassword" + resp := performTestSignupAndVerify(ts, email, password, c.expectToken) + require.Equal(ts.T(), c.expectedCode, resp.Code) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + if c.expectToken { + require.NotEqual(t, "", accessTokenResp.Token) + } else { + require.Equal(t, "", accessTokenResp.Token) + } + + cleanupHook(ts, c.cleanupHookFunction) + }) + } +} + +func cleanupHook(ts *MFATestSuite, hookName string) { + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", hookName) + err := ts.API.db.RawQuery(cleanupHookSQL).Exec() + require.NoError(ts.T(), err) +} diff --git a/internal/api/middleware.go b/internal/api/middleware.go new file mode 100644 index 000000000..2754824fa --- /dev/null +++ b/internal/api/middleware.go @@ -0,0 +1,402 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + chimiddleware "github.com/go-chi/chi/v5/middleware" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/security" + "github.com/supabase/auth/internal/utilities" + + "github.com/didip/tollbooth/v5" + "github.com/didip/tollbooth/v5/limiter" + jwt "github.com/golang-jwt/jwt/v5" +) + +type FunctionHooks map[string][]string + +type AuthMicroserviceClaims struct { + jwt.RegisteredClaims + SiteURL string `json:"site_url"` + InstanceID string `json:"id"` + FunctionHooks FunctionHooks `json:"function_hooks"` +} + +func (f *FunctionHooks) UnmarshalJSON(b []byte) error { + var raw map[string][]string + err := json.Unmarshal(b, &raw) + if err == nil { + *f = FunctionHooks(raw) + return nil + } + // If unmarshaling into map[string][]string fails, try legacy format. + var legacy map[string]string + err = json.Unmarshal(b, &legacy) + if err != nil { + return err + } + if *f == nil { + *f = make(FunctionHooks) + } + for event, hook := range legacy { + (*f)[event] = []string{hook} + } + return nil +} + +var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered") + +func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { + return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { + c := req.Context() + + if limitHeader := a.config.RateLimitHeader; limitHeader != "" { + key := req.Header.Get(limitHeader) + + if key == "" { + log := observability.GetLogEntry(req).Entry + log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") + return c, nil + } else { + err := tollbooth.LimitByKeys(lmt, []string{key}) + if err != nil { + return c, tooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") + } + } + } + return c, nil + } +} + +func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) { + t, err := a.extractBearerToken(req) + if err != nil || t == "" { + return nil, err + } + + ctx, err := a.parseJWTClaims(t, req) + if err != nil { + return nil, err + } + + return a.requireAdmin(ctx) +} + +func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + config := a.config + + if !config.External.Email.Enabled { + return nil, badRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") + } + + return ctx, nil +} + +func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + config := a.config + + if !config.Security.Captcha.Enabled { + return ctx, nil + } + if _, err := a.requireAdminCredentials(w, req); err == nil { + // skip captcha validation if authorization header contains an admin role + return ctx, nil + } + if shouldIgnore := isIgnoreCaptchaRoute(req); shouldIgnore { + return ctx, nil + } + + body := &security.GotrueRequest{} + if err := retrieveRequestParams(req, body); err != nil { + return nil, err + } + + verificationResult, err := security.VerifyRequest(body, utilities.GetIPAddress(req), strings.TrimSpace(config.Security.Captcha.Secret), config.Security.Captcha.Provider) + if err != nil { + return nil, internalServerError("captcha verification process failed").WithInternalError(err) + } + + if !verificationResult.Success { + return nil, badRequestError(apierrors.ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) + } + + return ctx, nil +} + +func isIgnoreCaptchaRoute(req *http.Request) bool { + // captcha shouldn't be enabled on the following grant_types + // id_token, refresh_token, pkce + if req.URL.Path == "/token" && req.FormValue("grant_type") != "password" { + return true + } + return false +} + +func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + config := a.config + + xForwardedHost := req.Header.Get("X-Forwarded-Host") + xForwardedProto := req.Header.Get("X-Forwarded-Proto") + reqHost := req.URL.Hostname() + + if len(config.Mailer.ExternalHosts) > 0 { + // this server is configured to accept multiple external hosts, validate the host from the X-Forwarded-Host or Host headers + + hostname := "" + protocol := "https" + + if xForwardedHost != "" { + for _, host := range config.Mailer.ExternalHosts { + if host == xForwardedHost { + hostname = host + break + } + } + } else if reqHost != "" { + for _, host := range config.Mailer.ExternalHosts { + if host == reqHost { + hostname = host + break + } + } + } + + if hostname != "" { + if hostname == "localhost" { + // allow the use of HTTP only if the accepted hostname was localhost + if xForwardedProto == "http" || req.URL.Scheme == "http" { + protocol = "http" + } + } + + externalHostURL, err := url.ParseRequestURI(fmt.Sprintf("%s://%s", protocol, hostname)) + if err != nil { + return ctx, err + } + + return withExternalHost(ctx, externalHostURL), nil + } + } + + if xForwardedHost != "" || reqHost != "" { + // host has been provided to the request, but it hasn't been + // added to the allow list, raise a log message + // in Supabase platform the X-Forwarded-Host and full request + // URL are likely sanitzied before they reach the server + + fields := make(logrus.Fields) + + if xForwardedHost != "" { + fields["x_forwarded_host"] = xForwardedHost + } + + if xForwardedProto != "" { + fields["x_forwarded_proto"] = xForwardedProto + } + + if reqHost != "" { + fields["request_url_host"] = reqHost + + if req.URL.Scheme != "" { + fields["request_url_scheme"] = req.URL.Scheme + } + } + + logrus.WithFields(fields).Info("Request received external host in X-Forwarded-Host or Host headers, but the values have not been added to GOTRUE_MAILER_EXTERNAL_HOSTS and will not be used. To suppress this message add the host, or sanitize the headers before the request reaches Auth.") + } + + // either the provided external hosts don't match the allow list, or + // the server is not configured to accept multiple hosts -- use the + // configured external URL instead + + externalHostURL, err := url.ParseRequestURI(config.API.ExternalURL) + if err != nil { + return ctx, err + } + + return withExternalHost(ctx, externalHostURL), nil +} + +func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + if !a.config.SAML.Enabled { + return nil, notFoundError(apierrors.ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") + } + return ctx, nil +} + +func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + if !a.config.Security.ManualLinkingEnabled { + return nil, notFoundError(apierrors.ErrorCodeManualLinkingDisabled, "Manual linking is disabled") + } + return ctx, nil +} + +func (a *API) databaseCleanup(cleanup models.Cleaner) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrappedResp := chimiddleware.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(wrappedResp, r) + switch r.Method { + case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: + if (wrappedResp.Status() / 100) != 2 { + // don't do any cleanups for non-2xx responses + return + } + // continue + default: + return + } + + db := a.db.WithContext(r.Context()) + log := observability.GetLogEntry(r).Entry + + affectedRows, err := cleanup.Clean(db) + if err != nil { + log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed") + } else if affectedRows > 0 { + log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows") + } + }) + } +} + +// timeoutResponseWriter is a http.ResponseWriter that queues up a response +// body to be sent if the serving completes before the context has exceeded its +// deadline. +type timeoutResponseWriter struct { + sync.Mutex + + header http.Header + wroteHeader bool + snapHeader http.Header // snapshot of the header at the time WriteHeader was called + statusCode int + buf bytes.Buffer +} + +func (t *timeoutResponseWriter) Header() http.Header { + t.Lock() + defer t.Unlock() + + return t.header +} + +func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) { + t.Lock() + defer t.Unlock() + + if !t.wroteHeader { + t.writeHeaderLocked(http.StatusOK) + } + + return t.buf.Write(bytes) +} + +func (t *timeoutResponseWriter) WriteHeader(statusCode int) { + t.Lock() + defer t.Unlock() + + t.writeHeaderLocked(statusCode) +} + +func (t *timeoutResponseWriter) writeHeaderLocked(statusCode int) { + if t.wroteHeader { + // ignore multiple calls to WriteHeader + // once WriteHeader has been called once, a snapshot of the header map is taken + // and saved in snapHeader to be used in finallyWrite + return + } + + t.statusCode = statusCode + t.wroteHeader = true + t.snapHeader = t.header.Clone() +} + +func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) { + t.Lock() + defer t.Unlock() + + dst := w.Header() + for k, vv := range t.snapHeader { + dst[k] = vv + } + + if !t.wroteHeader { + t.statusCode = http.StatusOK + } + + w.WriteHeader(t.statusCode) + if _, err := w.Write(t.buf.Bytes()); err != nil { + logrus.WithError(err).Warn("Write failed") + } +} + +func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + + timeoutWriter := &timeoutResponseWriter{ + header: make(http.Header), + } + + panicChan := make(chan any, 1) + serverDone := make(chan struct{}) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + + next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) + close(serverDone) + }() + + select { + case p := <-panicChan: + panic(p) + + case <-serverDone: + timeoutWriter.finallyWrite(w) + + case <-ctx.Done(): + err := ctx.Err() + + if err == context.DeadlineExceeded { + httpError := &HTTPError{ + HTTPStatus: http.StatusGatewayTimeout, + ErrorCode: apierrors.ErrorCodeRequestTimeout, + Message: "Processing this request timed out, please retry after a moment.", + } + + httpError = httpError.WithInternalError(err) + + HandleResponseError(httpError, w, r) + } else { + // unrecognized context error, so we should wait for the server to finish + // and write out the response + <-serverDone + + timeoutWriter.finallyWrite(w) + } + } + }) + } +} diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go new file mode 100644 index 000000000..b0f319c18 --- /dev/null +++ b/internal/api/middleware_test.go @@ -0,0 +1,511 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/didip/tollbooth/v5" + "github.com/didip/tollbooth/v5/limiter" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" +) + +const ( + HCaptchaSecret string = "0x0000000000000000000000000000000000000000" + CaptchaResponse string = "10000000-aaaa-bbbb-cccc-000000000001" + TurnstileCaptchaSecret string = "1x0000000000000000000000000000000AA" +) + +type MiddlewareTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestMiddlewareFunctions(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &MiddlewareTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *MiddlewareTestSuite) TestVerifyCaptchaValid() { + ts.Config.Security.Captcha.Enabled = true + + adminClaims := &AccessTokenClaims{ + Role: "supabase_admin", + } + adminJwt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, adminClaims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + cases := []struct { + desc string + adminJwt string + captcha_token string + captcha_provider string + }{ + { + "Valid captcha response", + "", + CaptchaResponse, + "hcaptcha", + }, + { + "Valid captcha response", + "", + CaptchaResponse, + "turnstile", + }, + { + "Ignore captcha if admin role is present", + adminJwt, + "", + "hcaptcha", + }, + { + "Ignore captcha if admin role is present", + adminJwt, + "", + "turnstile", + }, + } + for _, c := range cases { + ts.Config.Security.Captcha.Provider = c.captcha_provider + if c.captcha_provider == "turnstile" { + ts.Config.Security.Captcha.Secret = TurnstileCaptchaSecret + } else if c.captcha_provider == "hcaptcha" { + ts.Config.Security.Captcha.Secret = HCaptchaSecret + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "secret", + "gotrue_meta_security": map[string]interface{}{ + "captcha_token": c.captcha_token, + }, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) + req.Header.Set("Content-Type", "application/json") + if c.adminJwt != "" { + req.Header.Set("Authorization", "Bearer "+c.adminJwt) + } + + beforeCtx := context.Background() + req = req.WithContext(beforeCtx) + + w := httptest.NewRecorder() + + afterCtx, err := ts.API.verifyCaptcha(w, req) + require.NoError(ts.T(), err) + + body, err := io.ReadAll(req.Body) + require.NoError(ts.T(), err) + + // re-initialize buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "secret", + "gotrue_meta_security": map[string]interface{}{ + "captcha_token": c.captcha_token, + }, + })) + + // check if body is the same + require.Equal(ts.T(), body, buffer.Bytes()) + require.Equal(ts.T(), afterCtx, beforeCtx) + } +} + +func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { + cases := []struct { + desc string + captchaConf *conf.CaptchaConfiguration + expectedCode int + expectedMsg string + }{ + { + "Captcha validation failed", + &conf.CaptchaConfiguration{ + Enabled: true, + Provider: "hcaptcha", + Secret: "test", + }, + http.StatusBadRequest, + "captcha protection: request disallowed (not-using-dummy-secret)", + }, + { + "Captcha validation failed", + &conf.CaptchaConfiguration{ + Enabled: true, + Provider: "turnstile", + Secret: "anothertest", + }, + http.StatusBadRequest, + "captcha protection: request disallowed (invalid-input-secret)", + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.Security.Captcha = *c.captchaConf + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "secret", + "gotrue_meta_security": map[string]interface{}{ + "captcha_token": CaptchaResponse, + }, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) + req.Header.Set("Content-Type", "application/json") + + req = req.WithContext(context.Background()) + + w := httptest.NewRecorder() + + _, err := ts.API.verifyCaptcha(w, req) + require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).HTTPStatus) + require.Equal(ts.T(), c.expectedMsg, err.(*HTTPError).Message) + }) + } +} + +func (ts *MiddlewareTestSuite) TestIsValidExternalHost() { + cases := []struct { + desc string + externalHosts []string + + requestURL string + headers http.Header + + expectedURL string + }{ + { + desc: "no defined external hosts, no headers, no absolute request URL", + requestURL: "/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "no defined external hosts, unauthorized X-Forwarded-Host without any external hosts", + headers: http.Header{ + "X-Forwarded-Host": []string{ + "external-host.com", + }, + }, + requestURL: "/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "defined external hosts, unauthorized X-Forwarded-Host", + externalHosts: []string{"authorized-host.com"}, + headers: http.Header{ + "X-Forwarded-Proto": []string{"https"}, + "X-Forwarded-Host": []string{ + "external-host.com", + }, + }, + requestURL: "/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "no defined external hosts, unauthorized Host", + requestURL: "https://external-host.com/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "defined external hosts, unauthorized Host", + externalHosts: []string{"authorized-host.com"}, + requestURL: "https://external-host.com/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "defined external hosts, authorized X-Forwarded-Host", + externalHosts: []string{"authorized-host.com"}, + headers: http.Header{ + "X-Forwarded-Proto": []string{"http"}, // this should be ignored and default to HTTPS + "X-Forwarded-Host": []string{ + "authorized-host.com", + }, + }, + requestURL: "https://X-Forwarded-Host-takes-precedence.com/some-path", + expectedURL: "https://authorized-host.com", + }, + + { + desc: "defined external hosts, authorized Host", + externalHosts: []string{"authorized-host.com"}, + requestURL: "https://authorized-host.com/some-path", + expectedURL: "https://authorized-host.com", + }, + + { + desc: "defined external hosts, authorized X-Forwarded-Host", + externalHosts: []string{"authorized-host.com"}, + headers: http.Header{ + "X-Forwarded-Proto": []string{"http"}, // this should be ignored and default to HTTPS + "X-Forwarded-Host": []string{ + "authorized-host.com", + }, + }, + requestURL: "https://X-Forwarded-Host-takes-precedence.com/some-path", + expectedURL: "https://authorized-host.com", + }, + + { + desc: "defined external hosts, authorized localhost in X-Forwarded-Host with HTTP", + externalHosts: []string{"localhost"}, + headers: http.Header{ + "X-Forwarded-Proto": []string{"http"}, + "X-Forwarded-Host": []string{ + "localhost", + }, + }, + requestURL: "/some-path", + expectedURL: "http://localhost", + }, + + { + desc: "defined external hosts, authorized localhost in Host with HTTP", + externalHosts: []string{"localhost"}, + requestURL: "http://localhost:3000/some-path", + expectedURL: "http://localhost", + }, + } + + require.NotEmpty(ts.T(), ts.API.config.API.ExternalURL) + + for _, c := range cases { + ts.Run(c.desc, func() { + req := httptest.NewRequest(http.MethodPost, c.requestURL, nil) + if c.headers != nil { + req.Header = c.headers + } + + originalHosts := ts.API.config.Mailer.ExternalHosts + ts.API.config.Mailer.ExternalHosts = c.externalHosts + + w := httptest.NewRecorder() + ctx, err := ts.API.isValidExternalHost(w, req) + + ts.API.config.Mailer.ExternalHosts = originalHosts + + require.NoError(ts.T(), err) + + externalURL := getExternalHost(ctx) + require.Equal(ts.T(), c.expectedURL, externalURL.String()) + }) + } +} + +func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { + cases := []struct { + desc string + isEnabled bool + expectedErr error + }{ + { + desc: "SAML not enabled", + isEnabled: false, + expectedErr: notFoundError(apierrors.ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"), + }, + { + desc: "SAML enabled", + isEnabled: true, + expectedErr: nil, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.SAML.Enabled = c.isEnabled + req := httptest.NewRequest("GET", "http://localhost", nil) + w := httptest.NewRecorder() + + _, err := ts.API.requireSAMLEnabled(w, req) + require.Equal(ts.T(), c.expectedErr, err) + }) + } +} + +func TestFunctionHooksUnmarshalJSON(t *testing.T) { + tests := []struct { + in string + ok bool + }{ + {`{ "signup" : "identity-signup" }`, true}, + {`{ "signup" : ["identity-signup"] }`, true}, + {`{ "signup" : {"foo" : "bar"} }`, false}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + var f FunctionHooks + err := json.Unmarshal([]byte(tt.in), &f) + if tt.ok { + assert.NoError(t, err) + assert.Equal(t, FunctionHooks{"signup": {"identity-signup"}}, f) + } else { + assert.Error(t, err) + } + }) + } +} + +func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() { + ts.Config.API.MaxRequestDuration = 5 * time.Microsecond + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + w := httptest.NewRecorder() + + timeoutHandler := timeoutMiddleware(ts.Config.API.MaxRequestDuration) + + slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Sleep for 1 second to simulate a slow handler which should trigger the timeout + time.Sleep(1 * time.Second) + ts.API.handler.ServeHTTP(w, r) + }) + timeoutHandler(slowHandler).ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusGatewayTimeout, w.Code) + + var data map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), apierrors.ErrorCodeRequestTimeout, data["error_code"]) + require.Equal(ts.T(), float64(504), data["code"]) + require.NotNil(ts.T(), data["msg"]) +} + +func TestTimeoutResponseWriter(t *testing.T) { + // timeoutResponseWriter should exhitbit a similar behavior as http.ResponseWriter + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + w1 := httptest.NewRecorder() + w2 := httptest.NewRecorder() + + timeoutHandler := timeoutMiddleware(time.Second * 10) + + redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // tries to redirect twice + http.Redirect(w, r, "http://localhost:3001/#message=first_message", http.StatusSeeOther) + + // overwrites the first + http.Redirect(w, r, "http://localhost:3001/second", http.StatusSeeOther) + }) + timeoutHandler(redirectHandler).ServeHTTP(w1, req) + redirectHandler.ServeHTTP(w2, req) + + require.Equal(t, w1.Result(), w2.Result()) +} + +func (ts *MiddlewareTestSuite) TestLimitHandler() { + ts.Config.RateLimitHeader = "X-Rate-Limit" + lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }) + + okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + b, _ := json.Marshal(map[string]interface{}{"message": "ok"}) + w.Write([]byte(b)) + }) + + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") + w := httptest.NewRecorder() + ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var data map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), "ok", data["message"]) + } + + // 6th request should fail and return a rate limit exceeded error + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") + w := httptest.NewRecorder() + ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) +} + +type MockCleanup struct { + mock.Mock +} + +func (m *MockCleanup) Clean(db *storage.Connection) (int, error) { + m.Called(db) + return 0, nil +} + +func (ts *MiddlewareTestSuite) TestDatabaseCleanup() { + testHandler := func(statusCode int) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(statusCode) + b, _ := json.Marshal(map[string]interface{}{"message": "ok"}) + w.Write([]byte(b)) + }) + } + + cases := []struct { + desc string + statusCode int + method string + }{ + { + desc: "Run cleanup successfully", + statusCode: http.StatusOK, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if GET", + statusCode: http.StatusOK, + method: http.MethodGet, + }, + { + desc: "Skip cleanup if 3xx", + statusCode: http.StatusSeeOther, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if 4xx", + statusCode: http.StatusBadRequest, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if 5xx", + statusCode: http.StatusInternalServerError, + method: http.MethodPost, + }, + } + + mockCleanup := new(MockCleanup) + mockCleanup.On("Clean", mock.Anything).Return(0, nil) + for _, c := range cases { + ts.Run("DatabaseCleanup", func() { + req := httptest.NewRequest(c.method, "http://localhost", nil) + w := httptest.NewRecorder() + ts.API.databaseCleanup(mockCleanup)(testHandler(c.statusCode)).ServeHTTP(w, req) + require.Equal(ts.T(), c.statusCode, w.Code) + }) + } + mockCleanup.AssertNumberOfCalls(ts.T(), "Clean", 1) +} diff --git a/internal/api/opentelemetry-tracer_test.go b/internal/api/opentelemetry-tracer_test.go new file mode 100644 index 000000000..4aeddce5a --- /dev/null +++ b/internal/api/opentelemetry-tracer_test.go @@ -0,0 +1,93 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + semconv "go.opentelemetry.io/otel/semconv/v1.25.0" +) + +type OpenTelemetryTracerTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestOpenTelemetryTracer(t *testing.T) { + api, config, err := setupAPIForTestWithCallback(func(config *conf.GlobalConfiguration, conn *storage.Connection) { + if config != nil { + config.Tracing.Enabled = true + config.Tracing.Exporter = conf.OpenTelemetryTracing + } + }) + + require.NoError(t, err) + + ts := &OpenTelemetryTracerTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func getAttribute(attributes []attribute.KeyValue, key attribute.Key) *attribute.Value { + for _, value := range attributes { + if value.Key == key { + return &value.Value + } + } + + return nil +} + +func (ts *OpenTelemetryTracerTestSuite) TestOpenTelemetryTracer_Spans() { + exporter := tracetest.NewInMemoryExporter() + bsp := sdktrace.NewSimpleSpanProcessor(exporter) + traceProvider := sdktrace.NewTracerProvider( + sdktrace.WithSampler(sdktrace.AlwaysSample()), + sdktrace.WithSpanProcessor(bsp), + ) + otel.SetTracerProvider(traceProvider) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "http://localhost/something1", nil) + req.Header.Set("User-Agent", "whatever") + ts.API.handler.ServeHTTP(w, req) + + req = httptest.NewRequest(http.MethodGet, "http://localhost/something2", nil) + req.Header.Set("User-Agent", "whatever") + ts.API.handler.ServeHTTP(w, req) + + spanStubs := exporter.GetSpans() + spans := spanStubs.Snapshots() + + if assert.Equal(ts.T(), 2, len(spans)) { + attributes1 := spans[0].Attributes() + method1 := getAttribute(attributes1, semconv.HTTPMethodKey) + assert.Equal(ts.T(), "POST", method1.AsString()) + url1 := getAttribute(attributes1, semconv.HTTPTargetKey) + assert.Equal(ts.T(), "/something1", url1.AsString()) + statusCode1 := getAttribute(attributes1, semconv.HTTPStatusCodeKey) + assert.Equal(ts.T(), int64(404), statusCode1.AsInt64()) + + attributes2 := spans[1].Attributes() + method2 := getAttribute(attributes2, semconv.HTTPMethodKey) + assert.Equal(ts.T(), "GET", method2.AsString()) + url2 := getAttribute(attributes2, semconv.HTTPTargetKey) + assert.Equal(ts.T(), "/something2", url2.AsString()) + statusCode2 := getAttribute(attributes2, semconv.HTTPStatusCodeKey) + assert.Equal(ts.T(), int64(404), statusCode2.AsInt64()) + } +} diff --git a/internal/api/options.go b/internal/api/options.go new file mode 100644 index 000000000..9053c2f97 --- /dev/null +++ b/internal/api/options.go @@ -0,0 +1,102 @@ +package api + +import ( + "time" + + "github.com/didip/tollbooth/v5" + "github.com/didip/tollbooth/v5/limiter" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/ratelimit" +) + +type Option interface { + apply(*API) +} + +type LimiterOptions struct { + Email ratelimit.Limiter + Phone ratelimit.Limiter + + Signups *limiter.Limiter + AnonymousSignIns *limiter.Limiter + Recover *limiter.Limiter + Resend *limiter.Limiter + MagicLink *limiter.Limiter + Otp *limiter.Limiter + Token *limiter.Limiter + Verify *limiter.Limiter + User *limiter.Limiter + FactorVerify *limiter.Limiter + FactorChallenge *limiter.Limiter + SSO *limiter.Limiter + SAMLAssertion *limiter.Limiter +} + +func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } + +func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { + o := &LimiterOptions{} + + o.Email = ratelimit.New(gc.RateLimitEmailSent) + o.Phone = ratelimit.New(gc.RateLimitSmsSent) + + o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(int(gc.RateLimitAnonymousUsers)).SetMethods([]string{"POST"}) + + o.Token = tollbooth.NewLimiter(gc.RateLimitTokenRefresh/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.Verify = tollbooth.NewLimiter(gc.RateLimitVerify/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.User = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.FactorVerify = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60, + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Minute, + }).SetBurst(30) + + o.FactorChallenge = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60, + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Minute, + }).SetBurst(30) + + o.SSO = tollbooth.NewLimiter(gc.RateLimitSso/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.SAMLAssertion = tollbooth.NewLimiter(gc.SAML.RateLimitAssertion/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.Signups = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + // These all use the OTP limit per 5 min with 1hour ttl and burst of 30. + o.Recover = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.MagicLink = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.Otp = newLimiterPer5mOver1h(gc.RateLimitOtp) + return o +} + +func newLimiterPer5mOver1h(rate float64) *limiter.Limiter { + freq := rate / (60 * 5) + lim := tollbooth.NewLimiter(freq, &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + return lim +} diff --git a/internal/api/options_test.go b/internal/api/options_test.go new file mode 100644 index 000000000..c4c1d1623 --- /dev/null +++ b/internal/api/options_test.go @@ -0,0 +1,30 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/supabase/auth/internal/conf" +) + +func TestNewLimiterOptions(t *testing.T) { + cfg := &conf.GlobalConfiguration{} + cfg.ApplyDefaults() + + rl := NewLimiterOptions(cfg) + assert.NotNil(t, rl.Email) + assert.NotNil(t, rl.Phone) + assert.NotNil(t, rl.Signups) + assert.NotNil(t, rl.AnonymousSignIns) + assert.NotNil(t, rl.Recover) + assert.NotNil(t, rl.Resend) + assert.NotNil(t, rl.MagicLink) + assert.NotNil(t, rl.Otp) + assert.NotNil(t, rl.Token) + assert.NotNil(t, rl.Verify) + assert.NotNil(t, rl.User) + assert.NotNil(t, rl.FactorVerify) + assert.NotNil(t, rl.FactorChallenge) + assert.NotNil(t, rl.SSO) + assert.NotNil(t, rl.SAMLAssertion) +} diff --git a/internal/api/otp.go b/internal/api/otp.go new file mode 100644 index 000000000..916d61deb --- /dev/null +++ b/internal/api/otp.go @@ -0,0 +1,238 @@ +package api + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + "github.com/sethvargo/go-password/password" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// OtpParams contains the request body params for the otp endpoint +type OtpParams struct { + Email string `json:"email"` + Phone string `json:"phone"` + CreateUser bool `json:"create_user"` + Data map[string]interface{} `json:"data"` + Channel string `json:"channel"` + CodeChallengeMethod string `json:"code_challenge_method"` + CodeChallenge string `json:"code_challenge"` +} + +// SmsParams contains the request body params for sms otp +type SmsParams struct { + Phone string `json:"phone"` + Channel string `json:"channel"` + Data map[string]interface{} `json:"data"` + CodeChallengeMethod string `json:"code_challenge_method"` + CodeChallenge string `json:"code_challenge"` +} + +func (p *OtpParams) Validate() error { + if p.Email != "" && p.Phone != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided") + } + if p.Email != "" && p.Channel != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Channel should only be specified with Phone OTP") + } + if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { + return err + } + return nil +} + +func (p *SmsParams) Validate(config *conf.GlobalConfiguration) error { + var err error + p.Phone, err = validatePhone(p.Phone) + if err != nil { + return err + } + if !sms_provider.IsValidMessageChannel(p.Channel, config) { + return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) + } + return nil +} + +// Otp returns the MagicLink or SmsOtp handler based on the request body params +func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { + params := &OtpParams{ + CreateUser: true, + } + if params.Data == nil { + params.Data = make(map[string]interface{}) + } + + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.Validate(); err != nil { + return err + } + if params.Data == nil { + params.Data = make(map[string]interface{}) + } + + if ok, err := a.shouldCreateUser(r, params); !ok { + return unprocessableEntityError(apierrors.ErrorCodeOTPDisabled, "Signups not allowed for otp") + } else if err != nil { + return err + } + + if params.Email != "" { + return a.MagicLink(w, r) + } else if params.Phone != "" { + return a.SmsOtp(w, r) + } + + return badRequestError(apierrors.ErrorCodeValidationFailed, "One of email or phone must be set") +} + +type SmsOtpResponse struct { + MessageID string `json:"message_id,omitempty"` +} + +// SmsOtp sends the user an otp via sms +func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + if !config.External.Phone.Enabled { + return badRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Unsupported phone provider") + } + var err error + + params := &SmsParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + // For backwards compatibility, we default to SMS if params Channel is not specified + if params.Phone != "" && params.Channel == "" { + params.Channel = sms_provider.SMSProvider + } + + if err := params.Validate(config); err != nil { + return err + } + + var isNewUser bool + aud := a.requestAud(ctx, r) + user, err := models.FindUserByPhoneAndAudience(db, params.Phone, aud) + if err != nil { + if models.IsNotFoundError(err) { + isNewUser = true + } else { + return internalServerError("Database error finding user").WithInternalError(err) + } + } + if user != nil { + isNewUser = !user.IsPhoneConfirmed() + } + if isNewUser { + // User either doesn't exist or hasn't completed the signup process. + // Sign them up with temporary password. + password, err := password.Generate(64, 10, 1, false, true) + if err != nil { + return internalServerError("error creating user").WithInternalError(err) + } + + signUpParams := &SignupParams{ + Phone: params.Phone, + Password: password, + Data: params.Data, + Channel: params.Channel, + } + newBodyContent, err := json.Marshal(signUpParams) + if err != nil { + // SignupParams must be marshallable + panic(err) + } + r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) + + fakeResponse := &responseStub{} + + if config.Sms.Autoconfirm { + // signups are autoconfirmed, send otp after signup + if err := a.Signup(fakeResponse, r); err != nil { + return err + } + + signUpParams := &SignupParams{ + Phone: params.Phone, + Channel: params.Channel, + } + newBodyContent, err := json.Marshal(signUpParams) + if err != nil { + // SignupParams must be marshallable + panic(err) + } + r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) + return a.SmsOtp(w, r) + } + + if err := a.Signup(fakeResponse, r); err != nil { + return err + } + return sendJSON(w, http.StatusOK, make(map[string]string)) + } + + messageID := "" + err = db.Transaction(func(tx *storage.Connection) error { + if err := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", map[string]interface{}{ + "channel": params.Channel, + }); err != nil { + return err + } + mID, serr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, params.Channel) + if serr != nil { + return serr + } + messageID = mID + return nil + }) + + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, SmsOtpResponse{ + MessageID: messageID, + }) +} + +func (a *API) shouldCreateUser(r *http.Request, params *OtpParams) (bool, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + + if !params.CreateUser { + ctx := r.Context() + aud := a.requestAud(ctx, r) + var err error + if params.Email != "" { + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return false, err + } + _, err = models.FindUserByEmailAndAudience(db, params.Email, aud) + } else if params.Phone != "" { + params.Phone, err = validatePhone(params.Phone) + if err != nil { + return false, err + } + _, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) + } + + if err != nil && models.IsNotFoundError(err) { + return false, nil + } + } + return true, nil +} diff --git a/internal/api/otp_test.go b/internal/api/otp_test.go new file mode 100644 index 000000000..7a99f3d9c --- /dev/null +++ b/internal/api/otp_test.go @@ -0,0 +1,312 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type OtpTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestOtp(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &OtpTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *OtpTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + +} + +func (ts *OtpTestSuite) TestOtpPKCE() { + ts.Config.External.Phone.Enabled = true + testCodeChallenge := "testtesttesttesttesttesttestteststeststesttesttesttest" + + var buffer bytes.Buffer + cases := []struct { + desc string + params OtpParams + expected struct { + code int + response map[string]interface{} + } + }{ + { + desc: "Test (PKCE) Success Magiclink Otp", + params: OtpParams{ + Email: "test@example.com", + CreateUser: true, + CodeChallengeMethod: "s256", + CodeChallenge: testCodeChallenge, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusOK, + make(map[string]interface{}), + }, + }, + { + desc: "Test (PKCE) Failure, no code challenge", + params: OtpParams{ + Email: "test@example.com", + CreateUser: true, + CodeChallengeMethod: "s256", + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "error_code": apierrors.ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", + }, + }, + }, + { + desc: "Test (PKCE) Failure, no code challenge method", + params: OtpParams{ + Email: "test@example.com", + CreateUser: true, + CodeChallenge: testCodeChallenge, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "error_code": apierrors.ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", + }, + }, + }, + { + desc: "Test (PKCE) Success, phone with valid params", + params: OtpParams{ + Phone: "123456789", + CreateUser: true, + CodeChallengeMethod: "s256", + CodeChallenge: testCodeChallenge, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusInternalServerError, + map[string]interface{}{ + "code": float64(http.StatusInternalServerError), + "msg": "Unable to get SMS provider", + }, + }, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), c.expected.code, w.Code) + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + }) + } +} + +func (ts *OtpTestSuite) TestOtp() { + // Configured to allow testing of invalid channel params + ts.Config.External.Phone.Enabled = true + cases := []struct { + desc string + params OtpParams + expected struct { + code int + response map[string]interface{} + } + }{ + { + desc: "Test Success Magiclink Otp", + params: OtpParams{ + Email: "test@example.com", + CreateUser: true, + Data: map[string]interface{}{ + "somedata": "metadata", + }, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusOK, + make(map[string]interface{}), + }, + }, + { + desc: "Test Failure Pass Both Email & Phone", + params: OtpParams{ + Email: "test@example.com", + Phone: "123456789", + CreateUser: true, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "error_code": apierrors.ErrorCodeValidationFailed, + "msg": "Only an email address or phone number should be provided", + }, + }, + }, + { + desc: "Test Failure invalid channel param", + params: OtpParams{ + Phone: "123456789", + Channel: "invalidchannel", + CreateUser: true, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "error_code": apierrors.ErrorCodeValidationFailed, + "msg": InvalidChannelError, + }, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), c.expected.code, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + // response should be empty + assert.Equal(ts.T(), data, c.expected.response) + }) + } +} + +func (ts *OtpTestSuite) TestNoSignupsForOtp() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "newuser@example.com", + "create_user": false, + })) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusUnprocessableEntity, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + // response should be empty + assert.Equal(ts.T(), data, map[string]interface{}{ + "code": float64(http.StatusUnprocessableEntity), + "error_code": apierrors.ErrorCodeOTPDisabled, + "msg": "Signups not allowed for otp", + }) +} + +func (ts *OtpTestSuite) TestSubsequentOtp() { + ts.Config.SMTP.MaxFrequency = 0 + userEmail := "foo@example.com" + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": userEmail, + })) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + newUser, err := models.FindUserByEmailAndAudience(ts.API.db, userEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), newUser.ConfirmationToken) + require.NotEmpty(ts.T(), newUser.ConfirmationSentAt) + require.Empty(ts.T(), newUser.RecoveryToken) + require.Empty(ts.T(), newUser.RecoverySentAt) + require.Empty(ts.T(), newUser.EmailConfirmedAt) + + // since the signup process hasn't been completed, + // subsequent requests for another magiclink should not create a recovery token + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": userEmail, + })) + + req = httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + user, err := models.FindUserByEmailAndAudience(ts.API.db, userEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user.ConfirmationToken) + require.NotEmpty(ts.T(), user.ConfirmationSentAt) + require.Empty(ts.T(), user.RecoveryToken) + require.Empty(ts.T(), user.RecoverySentAt) + require.Empty(ts.T(), user.EmailConfirmedAt) +} diff --git a/api/pagination.go b/internal/api/pagination.go similarity index 96% rename from api/pagination.go rename to internal/api/pagination.go index 1c2341081..386f40310 100644 --- a/api/pagination.go +++ b/internal/api/pagination.go @@ -6,7 +6,7 @@ import ( "net/url" "strconv" - "github.com/netlify/gotrue/models" + "github.com/supabase/auth/internal/models" ) const defaultPerPage = 50 diff --git a/internal/api/password.go b/internal/api/password.go new file mode 100644 index 000000000..78291a49f --- /dev/null +++ b/internal/api/password.go @@ -0,0 +1,74 @@ +package api + +import ( + "context" + "fmt" + "strings" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/apierrors" +) + +// BCrypt hashed passwords have a 72 character limit +const MaxPasswordLength = 72 + +// WeakPasswordError encodes an error that a password does not meet strength +// requirements. It is handled specially in errors.go as it gets transformed to +// a HTTPError with a special weak_password field that encodes the Reasons +// slice. +type WeakPasswordError struct { + Message string `json:"message,omitempty"` + Reasons []string `json:"reasons,omitempty"` +} + +func (e *WeakPasswordError) Error() string { + return e.Message +} + +func (a *API) checkPasswordStrength(ctx context.Context, password string) error { + config := a.config + + if len(password) > MaxPasswordLength { + return badRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Password cannot be longer than %v characters", MaxPasswordLength)) + } + + var messages, reasons []string + + if len(password) < config.Password.MinLength { + reasons = append(reasons, "length") + messages = append(messages, fmt.Sprintf("Password should be at least %d characters.", config.Password.MinLength)) + } + + for _, characterSet := range config.Password.RequiredCharacters { + if characterSet != "" && !strings.ContainsAny(password, characterSet) { + reasons = append(reasons, "characters") + + messages = append(messages, fmt.Sprintf("Password should contain at least one character of each: %s.", strings.Join(config.Password.RequiredCharacters, ", "))) + + break + } + } + + if config.Password.HIBP.Enabled { + pwned, err := a.hibpClient.Check(ctx, password) + if err != nil { + if config.Password.HIBP.FailClosed { + return internalServerError("Unable to perform password strength check with HaveIBeenPwned.org.").WithInternalError(err) + } else { + logrus.WithError(err).Warn("Unable to perform password strength check with HaveIBeenPwned.org, pwned passwords are being allowed") + } + } else if pwned { + reasons = append(reasons, "pwned") + messages = append(messages, "Password is known to be weak and easy to guess, please choose a different one.") + } + } + + if len(reasons) > 0 { + return &WeakPasswordError{ + Message: strings.Join(messages, " "), + Reasons: reasons, + } + } + + return nil +} diff --git a/internal/api/password_test.go b/internal/api/password_test.go new file mode 100644 index 000000000..48e4b097f --- /dev/null +++ b/internal/api/password_test.go @@ -0,0 +1,118 @@ +package api + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" +) + +func TestPasswordStrengthChecks(t *testing.T) { + examples := []struct { + MinLength int + RequiredCharacters []string + + Password string + Reasons []string + }{ + { + MinLength: 6, + Password: "12345", + Reasons: []string{ + "length", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "123", + Reasons: []string{ + "length", + "characters", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "a123", + Reasons: []string{ + "length", + "characters", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "ab123", + Reasons: []string{ + "length", + "characters", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "c123", + Reasons: []string{ + "length", + "characters", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "abc123", + Reasons: nil, + }, + { + MinLength: 6, + RequiredCharacters: []string{}, + Password: "zZgXb5gzyCNrV36qwbOSbKVQsVJd28mC1TwRpeB0y6sFNICJyjD6bILKJMsjyKDzBdaY5tmi8zY9BWJYmt3vULLmyafjIDLYjy8qhETu0mS2jj1uQBgSAzJn9Zjm8EFa", + Reasons: nil, + }, + } + + for i, example := range examples { + api := &API{ + config: &conf.GlobalConfiguration{ + Password: conf.PasswordConfiguration{ + MinLength: example.MinLength, + RequiredCharacters: conf.PasswordRequiredCharacters(example.RequiredCharacters), + }, + }, + } + + err := api.checkPasswordStrength(context.Background(), example.Password) + + switch e := err.(type) { + case *WeakPasswordError: + require.Equal(t, e.Reasons, example.Reasons, "Example %d failed with wrong reasons", i) + case *HTTPError: + require.Equal(t, e.ErrorCode, apierrors.ErrorCodeValidationFailed, "Example %d failed with wrong error code", i) + default: + require.NoError(t, err, "Example %d failed with error", i) + } + } +} diff --git a/internal/api/phone.go b/internal/api/phone.go new file mode 100644 index 000000000..2f9f83bce --- /dev/null +++ b/internal/api/phone.go @@ -0,0 +1,170 @@ +package api + +import ( + "bytes" + "net/http" + "regexp" + "strings" + "text/template" + "time" + + "github.com/supabase/auth/internal/hooks" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +var e164Format = regexp.MustCompile("^[1-9][0-9]{1,14}$") + +const ( + phoneConfirmationOtp = "confirmation" + phoneReauthenticationOtp = "reauthentication" +) + +func validatePhone(phone string) (string, error) { + phone = formatPhoneNumber(phone) + if isValid := validateE164Format(phone); !isValid { + return "", badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + } + return phone, nil +} + +// validateE164Format checks if phone number follows the E.164 format +func validateE164Format(phone string) bool { + return e164Format.MatchString(phone) +} + +// formatPhoneNumber removes "+" and whitespaces in a phone number +func formatPhoneNumber(phone string) string { + return strings.ReplaceAll(strings.TrimPrefix(phone, "+"), " ", "") +} + +// sendPhoneConfirmation sends an otp to the user's phone number +func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) { + config := a.config + + var token *string + var sentAt *time.Time + + includeFields := []string{} + switch otpType { + case phoneChangeVerification: + token = &user.PhoneChangeToken + sentAt = user.PhoneChangeSentAt + user.PhoneChange = phone + includeFields = append(includeFields, "phone_change", "phone_change_token", "phone_change_sent_at") + case phoneConfirmationOtp: + token = &user.ConfirmationToken + sentAt = user.PhoneConfirmationSentAt + includeFields = append(includeFields, "confirmation_token", "phone_confirmation_sent_at") + case phoneReauthenticationOtp: + token = &user.ReauthenticationToken + sentAt = user.ReauthenticationSentAt + includeFields = append(includeFields, "reauthentication_token", "reauthentication_sent_at") + default: + return "", internalServerError("invalid otp type") + } + + // intentionally keeping this before the test OTP, so that the behavior + // of regular and test OTPs is similar + if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { + return "", tooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, config.Sms.MaxFrequency)) + } + + now := time.Now() + + var otp, messageID string + + if testOTP, ok := config.Sms.GetTestOTP(phone, now); ok { + otp = testOTP + messageID = "test-otp" + } + + // not using test OTPs + if otp == "" { + // TODO(km): Deprecate this behaviour - rate limits should still be applied to autoconfirm + if !config.Sms.Autoconfirm { + // apply rate limiting before the sms is sent out + if ok := a.limiterOpts.Phone.Allow(); !ok { + return "", tooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") + } + } + otp = crypto.GenerateOtp(config.Sms.OtpLength) + + if config.Hook.SendSMS.Enabled { + input := hooks.SendSMSInput{ + User: user, + SMS: hooks.SMS{ + OTP: otp, + }, + } + output := hooks.SendSMSOutput{} + err := a.invokeHook(tx, r, &input, &output) + if err != nil { + return "", err + } + } else { + smsProvider, err := sms_provider.GetSmsProvider(*config) + if err != nil { + return "", internalServerError("Unable to get SMS provider").WithInternalError(err) + } + message, err := generateSMSFromTemplate(config.Sms.SMSTemplate, otp) + if err != nil { + return "", internalServerError("error generating sms template").WithInternalError(err) + } + messageID, err := smsProvider.SendMessage(phone, message, channel, otp) + if err != nil { + return messageID, unprocessableEntityError(apierrors.ErrorCodeSMSSendFailed, "Error sending %s OTP to provider: %v", otpType, err) + } + } + } + + *token = crypto.GenerateTokenHash(phone, otp) + + switch otpType { + case phoneConfirmationOtp: + user.PhoneConfirmationSentAt = &now + case phoneChangeVerification: + user.PhoneChangeSentAt = &now + case phoneReauthenticationOtp: + user.ReauthenticationSentAt = &now + } + + if err := tx.UpdateOnly(user, includeFields...); err != nil { + return messageID, errors.Wrap(err, "Database error updating user for phone") + } + + var ottErr error + switch otpType { + case phoneConfirmationOtp: + if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ConfirmationToken, models.ConfirmationToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating confirmation token for phone") + } + case phoneChangeVerification: + if err := models.CreateOneTimeToken(tx, user.ID, user.PhoneChange, user.PhoneChangeToken, models.PhoneChangeToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating phone change token") + } + case phoneReauthenticationOtp: + if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ReauthenticationToken, models.ReauthenticationToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating reauthentication token for phone") + } + } + if ottErr != nil { + return messageID, internalServerError("error creating one time token").WithInternalError(ottErr) + } + return messageID, nil +} + +func generateSMSFromTemplate(SMSTemplate *template.Template, otp string) (string, error) { + var message bytes.Buffer + if err := SMSTemplate.Execute(&message, struct { + Code string + }{Code: otp}); err != nil { + return "", err + } + return message.String(), nil +} diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go new file mode 100644 index 000000000..adc50f1a9 --- /dev/null +++ b/internal/api/phone_test.go @@ -0,0 +1,443 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type PhoneTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +type TestSmsProvider struct { + mock.Mock + + SentMessages int +} + +func (t *TestSmsProvider) SendMessage(phone, message, channel, otp string) (string, error) { + t.SentMessages += 1 + return "", nil +} +func (t *TestSmsProvider) VerifyOTP(phone, otp string) error { + return nil +} + +func TestPhone(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &PhoneTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *PhoneTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("123456789", "", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") +} + +func (ts *PhoneTestSuite) TestValidateE164Format() { + isValid := validateE164Format("0123456789") + assert.Equal(ts.T(), false, isValid) +} + +func (ts *PhoneTestSuite) TestFormatPhoneNumber() { + actual := formatPhoneNumber("+1 23456789 ") + assert.Equal(ts.T(), "123456789", actual) +} + +func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) + cases := []struct { + desc string + otpType string + expected error + }{ + { + desc: "send confirmation otp", + otpType: phoneConfirmationOtp, + expected: nil, + }, + { + desc: "send phone_change otp", + otpType: phoneChangeVerification, + expected: nil, + }, + { + desc: "send recovery otp", + otpType: phoneReauthenticationOtp, + expected: nil, + }, + { + desc: "send invalid otp type ", + otpType: "invalid otp type", + expected: internalServerError("invalid otp type"), + }, + } + + if useTestOTP { + ts.API.config.Sms.TestOTP = map[string]string{ + "123456789": "123456", + } + } else { + ts.API.config.Sms.TestOTP = nil + } + + for _, c := range cases { + ts.Run(c.desc, func() { + provider := &TestSmsProvider{} + sms_provider.MockProvider = provider + + _, err = ts.API.sendPhoneConfirmation(req, ts.API.db, u, "123456789", c.otpType, sms_provider.SMSProvider) + require.Equal(ts.T(), c.expected, err) + u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + if c.expected == nil { + if useTestOTP { + require.Equal(ts.T(), provider.SentMessages, 0) + } else { + require.Equal(ts.T(), provider.SentMessages, 1) + } + } + + switch c.otpType { + case phoneConfirmationOtp: + require.NotEmpty(ts.T(), u.ConfirmationToken) + require.NotEmpty(ts.T(), u.ConfirmationSentAt) + case phoneChangeVerification: + require.NotEmpty(ts.T(), u.PhoneChangeToken) + require.NotEmpty(ts.T(), u.PhoneChangeSentAt) + case phoneReauthenticationOtp: + require.NotEmpty(ts.T(), u.ReauthenticationToken) + require.NotEmpty(ts.T(), u.ReauthenticationSentAt) + default: + } + }) + } + // Reset at end of test + ts.API.config.Sms.TestOTP = nil + +} + +func (ts *PhoneTestSuite) TestSendPhoneConfirmation() { + doTestSendPhoneConfirmation(ts, false) +} + +func (ts *PhoneTestSuite) TestSendPhoneConfirmationWithTestOTP() { + doTestSendPhoneConfirmation(ts, true) +} + +func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + now := time.Now() + u.PhoneConfirmedAt = &now + require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user") + + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + require.NoError(ts.T(), err) + + cases := []struct { + desc string + endpoint string + method string + header string + body map[string]string + expected map[string]interface{} + }{ + { + desc: "Signup", + endpoint: "/signup", + method: http.MethodPost, + header: "", + body: map[string]string{ + "phone": "1234567890", + "password": "testpassword", + }, + expected: map[string]interface{}{ + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", + }, + }, + { + desc: "Sms OTP", + endpoint: "/otp", + method: http.MethodPost, + header: "", + body: map[string]string{ + "phone": "123456789", + }, + expected: map[string]interface{}{ + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", + }, + }, + { + desc: "Phone change", + endpoint: "/user", + method: http.MethodPut, + header: token, + body: map[string]string{ + "phone": "111111111", + }, + expected: map[string]interface{}{ + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", + }, + }, + { + desc: "Reauthenticate", + endpoint: "/reauthenticate", + method: http.MethodGet, + header: "", + body: nil, + expected: map[string]interface{}{ + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", + }, + }, + } + + smsProviders := []string{"twilio", "messagebird", "textlocal", "vonage"} + ts.Config.External.Phone.Enabled = true + ts.Config.Sms.Twilio.AccountSid = "" + ts.Config.Sms.Messagebird.AccessKey = "" + ts.Config.Sms.Textlocal.ApiKey = "" + ts.Config.Sms.Vonage.ApiKey = "" + + for _, c := range cases { + for _, provider := range smsProviders { + ts.Config.Sms.Provider = provider + desc := fmt.Sprintf("[%v] %v", provider, c.desc) + ts.Run(desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + req := httptest.NewRequest(c.method, "http://localhost"+c.endpoint, &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected["code"], w.Code) + + body := w.Body.String() + require.True(ts.T(), + strings.Contains(body, "Unable to get SMS provider") || + strings.Contains(body, "Error finding SMS provider") || + strings.Contains(body, "Failed to get SMS provider"), + "unexpected body message %q", body, + ) + }) + } + } +} +func (ts *PhoneTestSuite) TestSendSMSHook() { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + now := time.Now() + u.PhoneConfirmedAt = &now + require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user") + + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + require.NoError(ts.T(), err) + + // We setup a job table to enqueue SMS requests to send. Similar in spirit to the pg_boss postgres extension + createJobsTableSQL := `CREATE TABLE job_queue ( + id serial PRIMARY KEY, + job_type text, + payload jsonb, + status text DEFAULT 'pending', -- Possible values: 'pending', 'processing', 'completed', 'failed' + created_at timestamp without time zone DEFAULT NOW() + );` + require.NoError(ts.T(), ts.API.db.RawQuery(createJobsTableSQL).Exec()) + + type sendSMSHookTestCase struct { + desc string + uri string + endpoint string + method string + header string + body map[string]string + hookFunctionSQL string + expectedCode int + expectToken bool + hookFunctionIdentifier string + } + cases := []sendSMSHookTestCase{ + { + desc: "Phone signup using Hook", + endpoint: "/signup", + method: http.MethodPost, + uri: "pg-functions://postgres/auth/send_sms_signup", + hookFunctionSQL: ` + create or replace function send_sms_signup(input jsonb) + returns json as $$ + begin + insert into job_queue(job_type, payload) + values ('sms_signup', input); + return input; + end; $$ language plpgsql;`, + header: "", + body: map[string]string{ + "phone": "1234567890", + "password": "testpassword", + }, + expectedCode: http.StatusOK, + hookFunctionIdentifier: "send_sms_signup(input jsonb)", + }, + { + desc: "SMS OTP sign in using hook", + endpoint: "/otp", + method: http.MethodPost, + uri: "pg-functions://postgres/auth/send_sms_otp", + hookFunctionSQL: ` + create or replace function send_sms_otp(input jsonb) + returns json as $$ + begin + insert into job_queue(job_type, payload) + values ('sms_signup', input); + return input; + end; $$ language plpgsql;`, + header: "", + body: map[string]string{ + "phone": "123456789", + }, + expectToken: false, + expectedCode: http.StatusOK, + hookFunctionIdentifier: "send_sms_otp(input jsonb)", + }, + { + desc: "Phone Change", + endpoint: "/user", + method: http.MethodPut, + uri: "pg-functions://postgres/auth/send_sms_phone_change", + hookFunctionSQL: ` + create or replace function send_sms_phone_change(input jsonb) + returns json as $$ + begin + insert into job_queue(job_type, payload) + values ('phone_change', input); + return input; + end; $$ language plpgsql;`, + header: token, + body: map[string]string{ + "phone": "111111111", + }, + expectToken: true, + expectedCode: http.StatusOK, + hookFunctionIdentifier: "send_sms_phone_change(input jsonb)", + }, + { + desc: "Reauthenticate", + endpoint: "/reauthenticate", + method: http.MethodGet, + uri: "pg-functions://postgres/auth/reauthenticate", + hookFunctionSQL: ` + create or replace function reauthenticate(input jsonb) + returns json as $$ + begin + return input; + end; $$ language plpgsql;`, + header: "", + body: nil, + expectToken: true, + expectedCode: http.StatusOK, + hookFunctionIdentifier: "reauthenticate(input jsonb)", + }, + { + desc: "SMS OTP Hook (Error)", + endpoint: "/otp", + method: http.MethodPost, + uri: "pg-functions://postgres/auth/send_sms_otp_failure", + hookFunctionSQL: ` + create or replace function send_sms_otp(input jsonb) + returns json as $$ + begin + RAISE EXCEPTION 'Intentional Error for Testing'; + end; $$ language plpgsql;`, + header: "", + body: map[string]string{ + "phone": "123456789", + }, + expectToken: false, + expectedCode: http.StatusInternalServerError, + hookFunctionIdentifier: "send_sms_otp_failure(input jsonb)", + }, + } + + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + + ts.Config.External.Phone.Enabled = true + ts.Config.Hook.SendSMS.Enabled = true + ts.Config.Hook.SendSMS.URI = c.uri + // Disable FrequencyLimit to allow back to back sending + ts.Config.Sms.MaxFrequency = 0 * time.Second + require.NoError(ts.T(), ts.Config.Hook.SendSMS.PopulateExtensibilityPoint()) + + require.NoError(t, ts.API.db.RawQuery(c.hookFunctionSQL).Exec()) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + req := httptest.NewRequest(c.method, "http://localhost"+c.endpoint, &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(t, c.expectedCode, w.Code, "Unexpected HTTP status code") + + // Delete the function and reset env + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.SendSMS.HookName) + require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec()) + ts.Config.Hook.SendSMS.Enabled = false + ts.Config.Sms.MaxFrequency = 1 * time.Second + }) + } + + // Cleanup + deleteJobsTableSQL := `drop table if exists job_queue` + require.NoError(ts.T(), ts.API.db.RawQuery(deleteJobsTableSQL).Exec()) + +} diff --git a/internal/api/pkce.go b/internal/api/pkce.go new file mode 100644 index 000000000..b4feeae5b --- /dev/null +++ b/internal/api/pkce.go @@ -0,0 +1,99 @@ +package api + +import ( + "regexp" + + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +const ( + PKCEPrefix = "pkce_" + MinCodeChallengeLength = 43 + MaxCodeChallengeLength = 128 + InvalidPKCEParamsErrorMessage = "PKCE flow requires code_challenge_method and code_challenge" +) + +var codeChallengePattern = regexp.MustCompile("^[a-zA-Z._~0-9-]+$") + +func isValidCodeChallenge(codeChallenge string) (bool, error) { + // See RFC 7636 Section 4.2: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 + switch codeChallengeLength := len(codeChallenge); { + case codeChallengeLength < MinCodeChallengeLength, codeChallengeLength > MaxCodeChallengeLength: + return false, badRequestError(apierrors.ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) + case !codeChallengePattern.MatchString(codeChallenge): + return false, badRequestError(apierrors.ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") + default: + return true, nil + } +} + +func addFlowPrefixToToken(token string, flowType models.FlowType) string { + if isPKCEFlow(flowType) { + return flowType.String() + "_" + token + } else if isImplicitFlow(flowType) { + return token + } + return token +} + +func issueAuthCode(tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod) (string, error) { + flowState, err := models.FindFlowStateByUserID(tx, user.ID.String(), authenticationMethod) + if err != nil && models.IsNotFoundError(err) { + return "", unprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "No valid flow state found for user.") + } else if err != nil { + return "", err + } + if err := flowState.RecordAuthCodeIssuedAtTime(tx); err != nil { + return "", err + } + + return flowState.AuthCode, nil +} + +func isPKCEFlow(flowType models.FlowType) bool { + return flowType == models.PKCEFlow +} + +func isImplicitFlow(flowType models.FlowType) bool { + return flowType == models.ImplicitFlow +} + +func validatePKCEParams(codeChallengeMethod, codeChallenge string) error { + switch true { + case (codeChallenge == "") != (codeChallengeMethod == ""): + return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage) + case codeChallenge != "": + if valid, err := isValidCodeChallenge(codeChallenge); !valid { + return err + } + default: + // if both params are empty, just return nil + return nil + } + return nil +} + +func getFlowFromChallenge(codeChallenge string) models.FlowType { + if codeChallenge != "" { + return models.PKCEFlow + } else { + return models.ImplicitFlow + } +} + +// Should only be used with Auth Code of PKCE Flows +func generateFlowState(tx *storage.Connection, providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) { + codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam) + if err != nil { + return nil, err + } + flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID) + if err := tx.Create(flowState); err != nil { + return nil, err + } + return flowState, nil + +} diff --git a/internal/api/provider/apple.go b/internal/api/provider/apple.go new file mode 100644 index 000000000..508eaf137 --- /dev/null +++ b/internal/api/provider/apple.go @@ -0,0 +1,144 @@ +package provider + +import ( + "context" + "encoding/json" + "net/url" + "strconv" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const IssuerApple = "https://appleid.apple.com" + +// AppleProvider stores the custom config for apple provider +type AppleProvider struct { + *oauth2.Config + oidc *oidc.Provider +} + +type IsPrivateEmail bool + +// Apple returns an is_private_email field that could be a string or boolean value so we need to implement a custom unmarshaler +// https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_rest_api/authenticating_users_with_sign_in_with_apple +func (b *IsPrivateEmail) UnmarshalJSON(data []byte) error { + var boolVal bool + if err := json.Unmarshal(data, &boolVal); err == nil { + *b = IsPrivateEmail(boolVal) + return nil + } + + // ignore the error and try to unmarshal as a string + var strVal string + if err := json.Unmarshal(data, &strVal); err != nil { + return err + } + + var err error + boolVal, err = strconv.ParseBool(strVal) + if err != nil { + return err + } + + *b = IsPrivateEmail(boolVal) + return nil +} + +type appleName struct { + FirstName string `json:"firstName"` + LastName string `json:"lastName"` +} + +type appleUser struct { + Name appleName `json:"name"` + Email string `json:"email"` +} + +// NewAppleProvider creates a Apple account provider. +func NewAppleProvider(ctx context.Context, ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + if ext.URL != "" { + logrus.Warn("Apple OAuth provider has URL config set which is ignored (check GOTRUE_EXTERNAL_APPLE_URL)") + } + + oidcProvider, err := oidc.NewProvider(ctx, IssuerApple) + if err != nil { + return nil, err + } + + return &AppleProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oidcProvider.Endpoint(), + Scopes: []string{ + "email", + "name", + }, + RedirectURL: ext.RedirectURI, + }, + oidc: oidcProvider, + }, nil +} + +// GetOAuthToken returns the apple provider access token +func (p AppleProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + opts := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("client_id", p.ClientID), + oauth2.SetAuthURLParam("secret", p.ClientSecret), + } + return p.Exchange(context.Background(), code, opts...) +} + +func (p AppleProvider) AuthCodeURL(state string, args ...oauth2.AuthCodeOption) string { + opts := make([]oauth2.AuthCodeOption, 0, 1) + opts = append(opts, oauth2.SetAuthURLParam("response_mode", "form_post")) + authURL := p.Config.AuthCodeURL(state, opts...) + if authURL != "" { + if u, err := url.Parse(authURL); err != nil { + u.RawQuery = strings.ReplaceAll(u.RawQuery, "+", "%20") + authURL = u.String() + } + } + return authURL +} + +// GetUserData returns the user data fetched from the apple provider +func (p AppleProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + idToken := tok.Extra("id_token") + if tok.AccessToken == "" || idToken == nil { + // Apple returns user data only the first time + return &UserProvidedData{}, nil + } + + _, data, err := ParseIDToken(ctx, p.oidc, &oidc.Config{ + ClientID: p.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + + return data, nil +} + +// ParseUser parses the apple user's info +func (p AppleProvider) ParseUser(data string, userData *UserProvidedData) error { + u := &appleUser{} + err := json.Unmarshal([]byte(data), u) + if err != nil { + return err + } + + userData.Metadata.Name = strings.TrimSpace(u.Name.FirstName + " " + u.Name.LastName) + userData.Metadata.FullName = strings.TrimSpace(u.Name.FirstName + " " + u.Name.LastName) + return nil +} diff --git a/internal/api/provider/azure.go b/internal/api/provider/azure.go new file mode 100644 index 000000000..4a341f4d6 --- /dev/null +++ b/internal/api/provider/azure.go @@ -0,0 +1,164 @@ +package provider + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "regexp" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const IssuerAzureCommon = "https://login.microsoftonline.com/common/v2.0" +const IssuerAzureOrganizations = "https://login.microsoftonline.com/organizations/v2.0" + +// IssuerAzureMicrosoft is the OIDC issuer for microsoft.com accounts: +// https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims +const IssuerAzureMicrosoft = "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0" + +const ( + defaultAzureAuthBase = "login.microsoftonline.com/common" +) + +type azureProvider struct { + *oauth2.Config + + // ExpectedIssuer contains the OIDC issuer that should be expected when + // the authorize flow completes. For example, when using the "common" + // endpoint the authorization flow will end with an ID token that + // contains any issuer. In this case, ExpectedIssuer is an empty + // string, because any issuer is allowed. But if a developer sets up a + // tenant-specific authorization endpoint, then we must ensure that the + // ID token received is issued by that specific issuer, and so + // ExpectedIssuer contains the issuer URL of that tenant. + ExpectedIssuer string +} + +var azureIssuerRegexp = regexp.MustCompile("^https://login[.]microsoftonline[.]com/([^/]+)/v2[.]0/?$") +var azureCIAMIssuerRegexp = regexp.MustCompile("^https://[a-z0-9-]+[.]ciamlogin[.]com/([^/]+)/v2[.]0/?$") + +func IsAzureIssuer(issuer string) bool { + return azureIssuerRegexp.MatchString(issuer) +} + +func IsAzureCIAMIssuer(issuer string) bool { + return azureCIAMIssuerRegexp.MatchString(issuer) +} + +// NewAzureProvider creates a Azure account provider. +func NewAzureProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + oauthScopes := []string{"openid"} + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + authHost := chooseHost(ext.URL, defaultAzureAuthBase) + expectedIssuer := "" + + if ext.URL != "" { + expectedIssuer = authHost + "/v2.0" + + if !IsAzureIssuer(expectedIssuer) || !IsAzureCIAMIssuer(expectedIssuer) || expectedIssuer == IssuerAzureCommon || expectedIssuer == IssuerAzureOrganizations { + // in tests, the URL is a local server which should not + // be the expected issuer + // also, IssuerAzure (common) never actually issues any + // ID tokens so it needs to be ignored + expectedIssuer = "" + } + } + + return &azureProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/oauth2/v2.0/authorize", + TokenURL: authHost + "/oauth2/v2.0/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + ExpectedIssuer: expectedIssuer, + }, nil +} + +func (g azureProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func DetectAzureIDTokenIssuer(ctx context.Context, idToken string) (string, error) { + var payload struct { + Issuer string `json:"iss"` + } + + parts := strings.Split(idToken, ".") + if len(parts) != 3 { + return "", fmt.Errorf("azure: invalid ID token") + } + + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("azure: invalid ID token %w", err) + } + + if err := json.Unmarshal(payloadBytes, &payload); err != nil { + return "", fmt.Errorf("azure: invalid ID token %w", err) + } + + return payload.Issuer, nil +} + +func (g azureProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + idToken := tok.Extra("id_token") + + if idToken != nil { + issuer, err := DetectAzureIDTokenIssuer(ctx, idToken.(string)) + if err != nil { + return nil, err + } + + // Allow basic Azure issuers, except when the expected issuer + // is configured to be the Azure CIAM issuer, allow CIAM + // issuers to pass. + if !IsAzureIssuer(issuer) && (IsAzureCIAMIssuer(g.ExpectedIssuer) && !IsAzureCIAMIssuer(issuer)) { + return nil, fmt.Errorf("azure: ID token issuer not valid %q", issuer) + } + + if g.ExpectedIssuer != "" && issuer != g.ExpectedIssuer { + // Since ExpectedIssuer was set, then the developer had + // setup GoTrue to use the tenant-specific + // authorization endpoint, which in-turn means that + // only those tenant's ID tokens will be accepted. + return nil, fmt.Errorf("azure: ID token issuer %q does not match expected issuer %q", issuer, g.ExpectedIssuer) + } + + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + return nil, err + } + + _, data, err := ParseIDToken(ctx, provider, &oidc.Config{ + ClientID: g.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + + return data, nil + } + + // Only ID tokens supported, UserInfo endpoint has a history of being less secure. + + return nil, fmt.Errorf("azure: no OIDC ID token present in response") +} diff --git a/internal/api/provider/azure_test.go b/internal/api/provider/azure_test.go new file mode 100644 index 000000000..316cb08ba --- /dev/null +++ b/internal/api/provider/azure_test.go @@ -0,0 +1,29 @@ +package provider + +import "testing" + +func TestIsAzureIssuer(t *testing.T) { + positiveExamples := []string{ + "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0", + "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0/", + "https://login.microsoftonline.com/common/v2.0", + } + + negativeExamples := []string{ + "http://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0", + "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0?something=else", + "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0/extra", + } + + for _, example := range positiveExamples { + if !IsAzureIssuer(example) { + t.Errorf("Example %q should be treated as a valid Azure issuer", example) + } + } + + for _, example := range negativeExamples { + if IsAzureIssuer(example) { + t.Errorf("Example %q should be treated as not a valid Azure issuer", example) + } + } +} diff --git a/api/provider/bitbucket.go b/internal/api/provider/bitbucket.go similarity index 89% rename from api/provider/bitbucket.go rename to internal/api/provider/bitbucket.go index 6fb0f7330..e5fae5c91 100644 --- a/api/provider/bitbucket.go +++ b/internal/api/provider/bitbucket.go @@ -2,9 +2,8 @@ package provider import ( "context" - "errors" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) @@ -38,7 +37,7 @@ type bitbucketEmails struct { // NewBitbucketProvider creates a Bitbucket account provider. func NewBitbucketProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -47,7 +46,7 @@ func NewBitbucketProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, e return &bitbucketProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: authHost + "/site/oauth2/authorize", @@ -61,7 +60,7 @@ func NewBitbucketProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, e } func (g bitbucketProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) + return g.Exchange(context.Background(), code) } func (g bitbucketProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { @@ -89,10 +88,6 @@ func (g bitbucketProvider) GetUserData(ctx context.Context, tok *oauth2.Token) ( } } - if len(data.Emails) <= 0 { - return nil, errors.New("Unable to find email with Bitbucket provider") - } - data.Metadata = &Claims{ Issuer: g.APIPath, Subject: u.ID, diff --git a/api/provider/discord.go b/internal/api/provider/discord.go similarity index 63% rename from api/provider/discord.go rename to internal/api/provider/discord.go index b323abfc9..50d413b7c 100644 --- a/api/provider/discord.go +++ b/internal/api/provider/discord.go @@ -2,11 +2,11 @@ package provider import ( "context" - "errors" "fmt" + "strconv" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) @@ -21,16 +21,17 @@ type discordProvider struct { type discordUser struct { Avatar string `json:"avatar"` - Discriminator int `json:"discriminator,string"` + Discriminator string `json:"discriminator"` Email string `json:"email"` ID string `json:"id"` Name string `json:"username"` + GlobalName string `json:"global_name"` Verified bool `json:"verified"` } // NewDiscordProvider creates a Discord account provider. func NewDiscordProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -47,7 +48,7 @@ func NewDiscordProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAu return &discordProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: apiPath + "/oauth2/authorize", @@ -61,7 +62,7 @@ func NewDiscordProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAu } func (g discordProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) + return g.Exchange(context.Background(), code) } func (g discordProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { @@ -70,17 +71,26 @@ func (g discordProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*U return nil, err } - if u.Email == "" { - return nil, errors.New("Unable to find email with Discord provider") + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: u.Verified, + Primary: true, + }} } var avatarURL string extension := "png" - // https://discord.com/developers/docs/reference#image-formatting-cdn-endpoints: - // In the case of the Default User Avatar endpoint, the value for - // user_discriminator in the path should be the user's discriminator modulo 5 if u.Avatar == "" { - avatarURL = fmt.Sprintf("https://cdn.discordapp.com/embed/avatars/%d.%s", u.Discriminator%5, extension) + if intDiscriminator, err := strconv.Atoi(u.Discriminator); err != nil { + return nil, err + } else { + // https://discord.com/developers/docs/reference#image-formatting-cdn-endpoints: + // In the case of the Default User Avatar endpoint, the value for + // user_discriminator in the path should be the user's discriminator modulo 5 + avatarURL = fmt.Sprintf("https://cdn.discordapp.com/embed/avatars/%d.%s", intDiscriminator%5, extension) + } } else { // https://discord.com/developers/docs/reference#image-formatting: // "In the case of endpoints that support GIFs, the hash will begin with a_ @@ -91,24 +101,20 @@ func (g discordProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*U avatarURL = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.%s", u.ID, u.Avatar, extension) } - return &UserProvidedData{ - Metadata: &Claims{ - Issuer: g.APIPath, - Subject: u.ID, - Name: u.Name, - Picture: avatarURL, - Email: u.Email, - EmailVerified: u.Verified, - - // To be deprecated - AvatarURL: avatarURL, - FullName: u.Name, - ProviderId: u.ID, + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: fmt.Sprintf("%v#%v", u.Name, u.Discriminator), + Picture: avatarURL, + CustomClaims: map[string]interface{}{ + "global_name": u.GlobalName, }, - Emails: []Email{{ - Email: u.Email, - Verified: u.Verified, - Primary: true, - }}, - }, nil + + // To be deprecated + AvatarURL: avatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + + return data, nil } diff --git a/internal/api/provider/errors.go b/internal/api/provider/errors.go new file mode 100644 index 000000000..67a20eaad --- /dev/null +++ b/internal/api/provider/errors.go @@ -0,0 +1,49 @@ +package provider + +import "fmt" + +type HTTPError struct { + Code int `json:"code"` + Message string `json:"msg"` + InternalError error `json:"-"` + InternalMessage string `json:"-"` + ErrorID string `json:"error_id,omitempty"` +} + +func (e *HTTPError) Error() string { + if e.InternalMessage != "" { + return e.InternalMessage + } + return fmt.Sprintf("%d: %s", e.Code, e.Message) +} + +func (e *HTTPError) Is(target error) bool { + return e.Error() == target.Error() +} + +// Cause returns the root cause error +func (e *HTTPError) Cause() error { + if e.InternalError != nil { + return e.InternalError + } + return e +} + +// WithInternalError adds internal error information to the error +func (e *HTTPError) WithInternalError(err error) *HTTPError { + e.InternalError = err + return e +} + +// WithInternalMessage adds internal message information to the error +func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError { + e.InternalMessage = fmt.Sprintf(fmtString, args...) + return e +} + +func httpError(code int, fmtString string, args ...interface{}) *HTTPError { + return &HTTPError{ + Code: code, + Message: fmt.Sprintf(fmtString, args...), + } +} diff --git a/api/provider/facebook.go b/internal/api/provider/facebook.go similarity index 70% rename from api/provider/facebook.go rename to internal/api/provider/facebook.go index 156552b27..e73c419da 100644 --- a/api/provider/facebook.go +++ b/internal/api/provider/facebook.go @@ -5,16 +5,17 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" - "errors" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) +const IssuerFacebook = "https://www.facebook.com" + const ( defaultFacebookAuthBase = "www.facebook.com" - defaultFacebookTokenBase = "graph.facebook.com" + defaultFacebookTokenBase = "graph.facebook.com" //#nosec G101 -- Not a secret value. defaultFacebookAPIBase = "graph.facebook.com" ) @@ -38,7 +39,7 @@ type facebookUser struct { // NewFacebookProvider creates a Facebook account provider. func NewFacebookProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -56,7 +57,7 @@ func NewFacebookProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA return &facebookProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, RedirectURL: ext.RedirectURI, Endpoint: oauth2.Endpoint{ @@ -70,7 +71,7 @@ func NewFacebookProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA } func (p facebookProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return p.Exchange(oauth2.NoContext, code) + return p.Exchange(context.Background(), code) } func (p facebookProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { @@ -84,30 +85,28 @@ func (p facebookProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (* return nil, err } - if u.Email == "" { - return nil, errors.New("Unable to find email with Facebook provider") - } - - return &UserProvidedData{ - Metadata: &Claims{ - Issuer: p.ProfileURL, - Subject: u.ID, - Name: strings.TrimSpace(u.FirstName + " " + u.LastName), - NickName: u.Alias, - Email: u.Email, - EmailVerified: true, // if email is returned, the email is verified by facebook already - Picture: u.Avatar.Data.URL, - - // To be deprecated - Slug: u.Alias, - AvatarURL: u.Avatar.Data.URL, - FullName: strings.TrimSpace(u.FirstName + " " + u.LastName), - ProviderId: u.ID, - }, - Emails: []Email{{ + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ Email: u.Email, Verified: true, Primary: true, - }}, - }, nil + }} + } + + data.Metadata = &Claims{ + Issuer: p.ProfileURL, + Subject: u.ID, + Name: strings.TrimSpace(u.FirstName + " " + u.LastName), + NickName: u.Alias, + Picture: u.Avatar.Data.URL, + + // To be deprecated + Slug: u.Alias, + AvatarURL: u.Avatar.Data.URL, + FullName: strings.TrimSpace(u.FirstName + " " + u.LastName), + ProviderId: u.ID, + } + + return data, nil } diff --git a/internal/api/provider/figma.go b/internal/api/provider/figma.go new file mode 100644 index 000000000..aae74939b --- /dev/null +++ b/internal/api/provider/figma.go @@ -0,0 +1,95 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +// Figma +// Reference: https://www.figma.com/developers/api#oauth2 + +const ( + defaultFigmaAuthBase = "www.figma.com" + defaultFigmaAPIBase = "api.figma.com" +) + +type figmaProvider struct { + *oauth2.Config + APIHost string +} + +type figmaUser struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"handle"` + AvatarURL string `json:"img_url"` +} + +// NewFigmaProvider creates a Figma account provider. +func NewFigmaProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultFigmaAuthBase) + apiHost := chooseHost(ext.URL, defaultFigmaAPIBase) + + oauthScopes := []string{ + "files:read", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + return &figmaProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/oauth", + TokenURL: apiHost + "/v1/oauth/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIHost: apiHost, + }, nil +} + +func (p figmaProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return p.Exchange(context.Background(), code) +} + +func (p figmaProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u figmaUser + if err := makeRequest(ctx, tok, p.Config, p.APIHost+"/v1/me", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: p.APIHost, + Subject: u.ID, + Name: u.Name, + Email: u.Email, + EmailVerified: true, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + return data, nil +} diff --git a/internal/api/provider/fly.go b/internal/api/provider/fly.go new file mode 100644 index 000000000..d9337524f --- /dev/null +++ b/internal/api/provider/fly.go @@ -0,0 +1,103 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultFlyAPIBase = "api.fly.io" +) + +type flyProvider struct { + *oauth2.Config + APIPath string +} + +type flyUser struct { + ResourceOwnerID string `json:"resource_owner_id"` + UserID string `json:"user_id"` + UserName string `json:"user_name"` + Email string `json:"email"` + Organizations []struct { + ID string `json:"id"` + Role string `json:"role"` + } `json:"organizations"` + Scope []string `json:"scope"` + Application map[string]string `json:"application"` + ExpiresIn int `json:"expires_in"` + CreatedAt int `json:"created_at"` +} + +// NewFlyProvider creates a Fly oauth provider. +func NewFlyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultFlyAPIBase) + + // Fly only provides the "read" scope. + // https://fly.io/docs/reference/extensions_api/#single-sign-on-flow + oauthScopes := []string{ + "read", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + return &flyProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/oauth/authorize", + TokenURL: authHost + "/oauth/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIPath: authHost, + }, nil +} + +func (p flyProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return p.Exchange(context.Background(), code) +} + +func (p flyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u flyUser + if err := makeRequest(ctx, tok, p.Config, p.APIPath+"/oauth/token/info", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: p.APIPath, + Subject: u.UserID, + FullName: u.UserName, + Email: u.Email, + EmailVerified: true, + ProviderId: u.UserID, + CustomClaims: map[string]interface{}{ + "resource_owner_id": u.ResourceOwnerID, + "organizations": u.Organizations, + "application": u.Application, + "scope": u.Scope, + "created_at": u.CreatedAt, + }, + } + return data, nil +} diff --git a/api/provider/github.go b/internal/api/provider/github.go similarity index 86% rename from api/provider/github.go rename to internal/api/provider/github.go index cba57f108..0da3e8842 100644 --- a/api/provider/github.go +++ b/internal/api/provider/github.go @@ -2,11 +2,10 @@ package provider import ( "context" - "errors" "strconv" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) @@ -38,7 +37,7 @@ type githubUserEmail struct { // NewGithubProvider creates a Github account provider. func NewGithubProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -58,7 +57,7 @@ func NewGithubProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut return &githubProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: authHost + "/login/oauth/authorize", @@ -72,7 +71,7 @@ func NewGithubProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut } func (g githubProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) + return g.Exchange(context.Background(), code) } func (g githubProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { @@ -105,15 +104,6 @@ func (g githubProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Us if e.Email != "" { data.Emails = append(data.Emails, Email{Email: e.Email, Verified: e.Verified, Primary: e.Primary}) } - - if e.Primary { - data.Metadata.Email = e.Email - data.Metadata.EmailVerified = e.Verified - } - } - - if len(data.Emails) <= 0 { - return nil, errors.New("Unable to find email with GitHub provider") } return data, nil diff --git a/api/provider/gitlab.go b/internal/api/provider/gitlab.go similarity index 83% rename from api/provider/gitlab.go rename to internal/api/provider/gitlab.go index 1d5c78240..4b5d70cb8 100644 --- a/api/provider/gitlab.go +++ b/internal/api/provider/gitlab.go @@ -2,11 +2,10 @@ package provider import ( "context" - "errors" "strconv" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) @@ -34,7 +33,7 @@ type gitlabUserEmail struct { // NewGitlabProvider creates a Gitlab account provider. func NewGitlabProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -49,7 +48,7 @@ func NewGitlabProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut host := chooseHost(ext.URL, defaultGitLabAuthBase) return &gitlabProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: host + "/oauth/authorize", @@ -63,7 +62,7 @@ func NewGitlabProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut } func (g gitlabProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) + return g.Exchange(context.Background(), code) } func (g gitlabProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { @@ -92,17 +91,11 @@ func (g gitlabProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Us data.Emails = append(data.Emails, Email{Email: u.Email, Verified: verified, Primary: true}) } - if len(data.Emails) <= 0 { - return nil, errors.New("Unable to find email with GitLab provider") - } - data.Metadata = &Claims{ - Issuer: g.Host, - Subject: strconv.Itoa(u.ID), - Name: u.Name, - Picture: u.AvatarURL, - Email: u.Email, - EmailVerified: true, + Issuer: g.Host, + Subject: strconv.Itoa(u.ID), + Name: u.Name, + Picture: u.AvatarURL, // To be deprecated AvatarURL: u.AvatarURL, diff --git a/internal/api/provider/google.go b/internal/api/provider/google.go new file mode 100644 index 000000000..03b76aebe --- /dev/null +++ b/internal/api/provider/google.go @@ -0,0 +1,144 @@ +package provider + +import ( + "context" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +type googleUser struct { + ID string `json:"id"` + Subject string `json:"sub"` + Issuer string `json:"iss"` + Name string `json:"name"` + AvatarURL string `json:"picture"` + Email string `json:"email"` + VerifiedEmail bool `json:"verified_email"` + EmailVerified bool `json:"email_verified"` + HostedDomain string `json:"hd"` +} + +func (u googleUser) IsEmailVerified() bool { + return u.VerifiedEmail || u.EmailVerified +} + +const IssuerGoogle = "https://accounts.google.com" + +var internalIssuerGoogle = IssuerGoogle + +type googleProvider struct { + *oauth2.Config + + oidc *oidc.Provider +} + +// NewGoogleProvider creates a Google OAuth2 identity provider. +func NewGoogleProvider(ctx context.Context, ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + if ext.URL != "" { + logrus.Warn("Google OAuth provider has URL config set which is ignored (check GOTRUE_EXTERNAL_GOOGLE_URL)") + } + + oauthScopes := []string{ + "email", + "profile", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + oidcProvider, err := oidc.NewProvider(ctx, internalIssuerGoogle) + if err != nil { + return nil, err + } + + return &googleProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oidcProvider.Endpoint(), + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + oidc: oidcProvider, + }, nil +} + +func (g googleProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +const UserInfoEndpointGoogle = "https://www.googleapis.com/userinfo/v2/me" + +var internalUserInfoEndpointGoogle = UserInfoEndpointGoogle + +func (g googleProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + if idToken := tok.Extra("id_token"); idToken != nil { + _, data, err := ParseIDToken(ctx, g.oidc, &oidc.Config{ + ClientID: g.Config.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + + return data, nil + } + + // This whole section offers legacy support in case the Google OAuth2 + // flow does not return an ID Token for the user, which appears to + // always be the case. + logrus.Info("Using Google OAuth2 user info endpoint, an ID token was not returned by Google") + + var u googleUser + if err := makeRequest(ctx, tok, g.Config, internalUserInfoEndpointGoogle, &u); err != nil { + return nil, err + } + + var data UserProvidedData + + if u.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: u.Email, + Verified: u.IsEmailVerified(), + Primary: true, + }) + } + + data.Metadata = &Claims{ + Issuer: internalUserInfoEndpointGoogle, + Subject: u.ID, + Name: u.Name, + Picture: u.AvatarURL, + Email: u.Email, + EmailVerified: u.IsEmailVerified(), + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + + return &data, nil +} + +// ResetGoogleProvider should only be used in tests! +func ResetGoogleProvider() { + internalIssuerGoogle = IssuerGoogle + internalUserInfoEndpointGoogle = UserInfoEndpointGoogle +} + +// OverrideGoogleProvider should only be used in tests! +func OverrideGoogleProvider(issuer, userInfo string) { + internalIssuerGoogle = issuer + internalUserInfoEndpointGoogle = userInfo +} diff --git a/internal/api/provider/kakao.go b/internal/api/provider/kakao.go new file mode 100644 index 000000000..2482b97a8 --- /dev/null +++ b/internal/api/provider/kakao.go @@ -0,0 +1,107 @@ +package provider + +import ( + "context" + "strconv" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultKakaoAuthBase = "kauth.kakao.com" + defaultKakaoAPIBase = "kapi.kakao.com" + IssuerKakao = "https://kauth.kakao.com" +) + +type kakaoProvider struct { + *oauth2.Config + APIHost string +} + +type kakaoUser struct { + ID int `json:"id"` + Account struct { + Profile struct { + Name string `json:"nickname"` + ProfileImageURL string `json:"profile_image_url"` + } `json:"profile"` + Email string `json:"email"` + EmailValid bool `json:"is_email_valid"` + EmailVerified bool `json:"is_email_verified"` + } `json:"kakao_account"` +} + +func (p kakaoProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return p.Exchange(context.Background(), code) +} + +func (p kakaoProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u kakaoUser + + if err := makeRequest(ctx, tok, p.Config, p.APIHost+"/v2/user/me", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + + if u.Account.Email != "" { + data.Emails = []Email{ + { + Email: u.Account.Email, + Verified: u.Account.EmailVerified && u.Account.EmailValid, + Primary: true, + }, + } + } + + data.Metadata = &Claims{ + Issuer: p.APIHost, + Subject: strconv.Itoa(u.ID), + + Name: u.Account.Profile.Name, + PreferredUsername: u.Account.Profile.Name, + + // To be deprecated + AvatarURL: u.Account.Profile.ProfileImageURL, + FullName: u.Account.Profile.Name, + ProviderId: strconv.Itoa(u.ID), + UserNameKey: u.Account.Profile.Name, + } + return data, nil +} + +func NewKakaoProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultKakaoAuthBase) + apiHost := chooseHost(ext.URL, defaultKakaoAPIBase) + + oauthScopes := []string{ + "account_email", + "profile_image", + "profile_nickname", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + return &kakaoProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStyleInParams, + AuthURL: authHost + "/oauth/authorize", + TokenURL: authHost + "/oauth/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIHost: apiHost, + }, nil +} diff --git a/internal/api/provider/keycloak.go b/internal/api/provider/keycloak.go new file mode 100644 index 000000000..39ccec5bf --- /dev/null +++ b/internal/api/provider/keycloak.go @@ -0,0 +1,98 @@ +package provider + +import ( + "context" + "errors" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +// Keycloak +type keycloakProvider struct { + *oauth2.Config + Host string +} + +type keycloakUser struct { + Name string `json:"name"` + Sub string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` +} + +// NewKeycloakProvider creates a Keycloak account provider. +func NewKeycloakProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + oauthScopes := []string{ + "profile", + "email", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + if ext.URL == "" { + return nil, errors.New("unable to find URL for the Keycloak provider") + } + + extURLlen := len(ext.URL) + if ext.URL[extURLlen-1] == '/' { + ext.URL = ext.URL[:extURLlen-1] + } + + return &keycloakProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: ext.URL + "/protocol/openid-connect/auth", + TokenURL: ext.URL + "/protocol/openid-connect/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + Host: ext.URL, + }, nil +} + +func (g keycloakProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g keycloakProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u keycloakUser + + if err := makeRequest(ctx, tok, g.Config, g.Host+"/protocol/openid-connect/userinfo", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: u.EmailVerified, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: g.Host, + Subject: u.Sub, + Name: u.Name, + Email: u.Email, + EmailVerified: u.EmailVerified, + + // To be deprecated + FullName: u.Name, + ProviderId: u.Sub, + } + + return data, nil + +} diff --git a/api/provider/linkedin.go b/internal/api/provider/linkedin.go similarity index 78% rename from api/provider/linkedin.go rename to internal/api/provider/linkedin.go index 331faf45c..bc33515e7 100644 --- a/api/provider/linkedin.go +++ b/internal/api/provider/linkedin.go @@ -2,10 +2,9 @@ package provider import ( "context" - "errors" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) @@ -37,6 +36,14 @@ type linkedinUser struct { } `json:"profilePicture"` } +func (u *linkedinUser) getAvatarUrl() string { + avatarURL := "" + if len(u.AvatarURL.DisplayImage.Elements) > 0 { + avatarURL = u.AvatarURL.DisplayImage.Elements[0].Identifiers[0].Identifier + } + return avatarURL +} + type linkedinName struct { Localized interface{} `json:"localized"` PreferredLocale linkedinLocale `json:"preferredLocale"` @@ -60,7 +67,7 @@ type linkedinElements struct { // NewLinkedinProvider creates a Linkedin account provider. func NewLinkedinProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -77,7 +84,7 @@ func NewLinkedinProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA return &linkedinProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: apiPath + "/oauth/v2/authorization", @@ -91,7 +98,7 @@ func NewLinkedinProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA } func (g linkedinProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) + return g.Exchange(context.Background(), code) } func GetName(name linkedinName) string { @@ -112,35 +119,31 @@ func (g linkedinProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (* return nil, err } - if len(e.Elements) <= 0 { - return nil, errors.New("Unable to find email with Linkedin provider") - } - - emails := []Email{} + data := &UserProvidedData{} if e.Elements[0].HandleTilde.EmailAddress != "" { // linkedin only returns the primary email which is verified for the r_emailaddress scope. - emails = append(emails, Email{ + data.Emails = []Email{{ Email: e.Elements[0].HandleTilde.EmailAddress, Primary: true, Verified: true, - }) + }} } - return &UserProvidedData{ - Metadata: &Claims{ - Issuer: g.APIPath, - Subject: u.ID, - Name: strings.TrimSpace(GetName(u.FirstName) + " " + GetName(u.LastName)), - Picture: u.AvatarURL.DisplayImage.Elements[0].Identifiers[0].Identifier, - Email: e.Elements[0].HandleTilde.EmailAddress, - EmailVerified: true, - - // To be deprecated - AvatarURL: u.AvatarURL.DisplayImage.Elements[0].Identifiers[0].Identifier, - FullName: strings.TrimSpace(GetName(u.FirstName) + " " + GetName(u.LastName)), - ProviderId: u.ID, - }, - Emails: emails, - }, nil + avatarURL := u.getAvatarUrl() + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: strings.TrimSpace(GetName(u.FirstName) + " " + GetName(u.LastName)), + Picture: avatarURL, + Email: e.Elements[0].HandleTilde.EmailAddress, + EmailVerified: true, + + // To be deprecated + AvatarURL: avatarURL, + FullName: strings.TrimSpace(GetName(u.FirstName) + " " + GetName(u.LastName)), + ProviderId: u.ID, + } + return data, nil } diff --git a/internal/api/provider/linkedin_oidc.go b/internal/api/provider/linkedin_oidc.go new file mode 100644 index 000000000..a5d94fa09 --- /dev/null +++ b/internal/api/provider/linkedin_oidc.go @@ -0,0 +1,81 @@ +package provider + +import ( + "context" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultLinkedinOIDCAPIBase = "api.linkedin.com" + IssuerLinkedin = "https://www.linkedin.com/oauth" +) + +type linkedinOIDCProvider struct { + *oauth2.Config + oidc *oidc.Provider + APIPath string +} + +// NewLinkedinOIDCProvider creates a Linkedin account provider via OIDC. +func NewLinkedinOIDCProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultLinkedinOIDCAPIBase) + + oauthScopes := []string{ + "openid", + "email", + "profile", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + oidcProvider, err := oidc.NewProvider(context.Background(), IssuerLinkedin) + if err != nil { + return nil, err + } + + return &linkedinOIDCProvider{ + oidc: oidcProvider, + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: apiPath + "/oauth/v2/authorization", + TokenURL: apiPath + "/oauth/v2/accessToken", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g linkedinOIDCProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g linkedinOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + idToken := tok.Extra("id_token") + if tok.AccessToken == "" || idToken == nil { + return &UserProvidedData{}, nil + } + + _, data, err := ParseIDToken(ctx, g.oidc, &oidc.Config{ + ClientID: g.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + return data, nil +} diff --git a/api/provider/notion.go b/internal/api/provider/notion.go similarity index 62% rename from api/provider/notion.go rename to internal/api/provider/notion.go index 5971b308c..f8d0ee706 100644 --- a/api/provider/notion.go +++ b/internal/api/provider/notion.go @@ -3,11 +3,12 @@ package provider import ( "context" "encoding/json" - "errors" - "io/ioutil" + "fmt" + "io" "net/http" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" "golang.org/x/oauth2" ) @@ -38,7 +39,7 @@ type notionUser struct { // NewNotionProvider creates a Notion account provider. func NewNotionProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -46,7 +47,7 @@ func NewNotionProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, erro return ¬ionProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: authHost + "/v1/oauth/authorize", @@ -59,7 +60,7 @@ func NewNotionProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, erro } func (g notionProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) + return g.Exchange(context.Background(), code) } func (g notionProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { @@ -76,39 +77,45 @@ func (g notionProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Us req.Header.Set("Notion-Version", notionApiVersion) req.Header.Set("Authorization", "Bearer "+tok.AccessToken) - client := &http.Client{} + client := &http.Client{Timeout: defaultTimeout} resp, err := client.Do(req) - if err != nil { return nil, err } + defer utilities.SafeClose(resp.Body) - body, _ := ioutil.ReadAll(resp.Body) - json.Unmarshal(body, &u) - defer resp.Body.Close() + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("a %v error occurred with retrieving user from notion", resp.StatusCode) + } - if u.Bot.Owner.User.Person.Email == "" { - return nil, errors.New("unable to find email with notion provider") + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + err = json.Unmarshal(body, &u) + if err != nil { + return nil, err } - return &UserProvidedData{ - Metadata: &Claims{ - Issuer: g.APIPath, - Subject: u.Bot.Owner.User.ID, - Name: u.Bot.Owner.User.Name, - Picture: u.Bot.Owner.User.AvatarURL, - Email: u.Bot.Owner.User.Person.Email, - EmailVerified: true, // Notion dosen't provide data on if email is verified. - - // To be deprecated - AvatarURL: u.Bot.Owner.User.AvatarURL, - FullName: u.Bot.Owner.User.Name, - ProviderId: u.Bot.Owner.User.ID, - }, - Emails: []Email{{ + data := &UserProvidedData{} + if u.Bot.Owner.User.Person.Email != "" { + data.Emails = []Email{{ Email: u.Bot.Owner.User.Person.Email, Verified: true, // Notion dosen't provide data on if email is verified. Primary: true, - }}, - }, nil + }} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.Bot.Owner.User.ID, + Name: u.Bot.Owner.User.Name, + Picture: u.Bot.Owner.User.AvatarURL, + + // To be deprecated + AvatarURL: u.Bot.Owner.User.AvatarURL, + FullName: u.Bot.Owner.User.Name, + ProviderId: u.Bot.Owner.User.ID, + } + return data, nil } diff --git a/internal/api/provider/oidc.go b/internal/api/provider/oidc.go new file mode 100644 index 000000000..51c88e639 --- /dev/null +++ b/internal/api/provider/oidc.go @@ -0,0 +1,410 @@ +package provider + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v5" +) + +type ParseIDTokenOptions struct { + SkipAccessTokenCheck bool + AccessToken string +} + +// OverrideVerifiers can be used to set a custom verifier for an OIDC provider +// (identified by the provider's Endpoint().AuthURL string). Should only be +// used in tests. +var OverrideVerifiers = make(map[string]func(context.Context, *oidc.Config) *oidc.IDTokenVerifier) + +// OverrideClock can be used to set a custom clock function to be used when +// parsing ID tokens. Should only be used in tests. +var OverrideClock func() time.Time + +func ParseIDToken(ctx context.Context, provider *oidc.Provider, config *oidc.Config, idToken string, options ParseIDTokenOptions) (*oidc.IDToken, *UserProvidedData, error) { + if config == nil { + config = &oidc.Config{ + // aud claim check to be performed by other flows + SkipClientIDCheck: true, + } + } + + if OverrideClock != nil { + clonedConfig := *config + clonedConfig.Now = OverrideClock + config = &clonedConfig + } + + verifier := provider.VerifierContext(ctx, config) + overrideVerifier, ok := OverrideVerifiers[provider.Endpoint().AuthURL] + if ok && overrideVerifier != nil { + verifier = overrideVerifier(ctx, config) + } + + token, err := verifier.Verify(ctx, idToken) + if err != nil { + return nil, nil, err + } + + var data *UserProvidedData + + switch token.Issuer { + case IssuerGoogle: + token, data, err = parseGoogleIDToken(token) + case IssuerApple: + token, data, err = parseAppleIDToken(token) + case IssuerLinkedin: + token, data, err = parseLinkedinIDToken(token) + case IssuerKakao: + token, data, err = parseKakaoIDToken(token) + case IssuerVercelMarketplace: + token, data, err = parseVercelMarketplaceIDToken(token) + default: + if IsAzureIssuer(token.Issuer) { + token, data, err = parseAzureIDToken(token) + } else { + token, data, err = parseGenericIDToken(token) + } + } + + if err != nil { + return nil, nil, err + } + + if !options.SkipAccessTokenCheck && token.AccessTokenHash != "" { + if err := token.VerifyAccessToken(options.AccessToken); err != nil { + return nil, nil, err + } + } + + return token, data, nil +} + +func parseGoogleIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims googleUser + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + + if claims.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: claims.Email, + Verified: claims.IsEmailVerified(), + Primary: true, + }) + } + + data.Metadata = &Claims{ + Issuer: claims.Issuer, + Subject: claims.Subject, + Name: claims.Name, + Picture: claims.AvatarURL, + + // To be deprecated + AvatarURL: claims.AvatarURL, + FullName: claims.Name, + ProviderId: claims.Subject, + } + + if claims.HostedDomain != "" { + data.Metadata.CustomClaims = map[string]any{ + "hd": claims.HostedDomain, + } + } + + return token, &data, nil +} + +type AppleIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + + AuthTime *float64 `json:"auth_time"` + IsPrivateEmail *IsPrivateEmail `json:"is_private_email"` +} + +func parseAppleIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims AppleIDTokenClaims + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + + data.Emails = append(data.Emails, Email{ + Email: claims.Email, + Verified: true, + Primary: true, + }) + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + ProviderId: token.Subject, + CustomClaims: make(map[string]any), + } + + if claims.IsPrivateEmail != nil { + data.Metadata.CustomClaims["is_private_email"] = *claims.IsPrivateEmail + } + + if claims.AuthTime != nil { + data.Metadata.CustomClaims["auth_time"] = *claims.AuthTime + } + + if len(data.Metadata.CustomClaims) < 1 { + data.Metadata.CustomClaims = nil + } + + return token, &data, nil +} + +type LinkedinIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + EmailVerified string `json:"email_verified"` + FamilyName string `json:"family_name"` + GivenName string `json:"given_name"` + Locale string `json:"locale"` + Picture string `json:"picture"` +} + +func parseLinkedinIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims LinkedinIDTokenClaims + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + emailVerified, err := strconv.ParseBool(claims.EmailVerified) + if err != nil { + return nil, nil, err + } + + if claims.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: claims.Email, + Verified: emailVerified, + Primary: true, + }) + } + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + Name: strings.TrimSpace(claims.GivenName + " " + claims.FamilyName), + GivenName: claims.GivenName, + FamilyName: claims.FamilyName, + Locale: claims.Locale, + Picture: claims.Picture, + ProviderId: token.Subject, + } + + return token, &data, nil +} + +type AzureIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + XMicrosoftEmailDomainOwnerVerified any `json:"xms_edov"` +} + +func (c *AzureIDTokenClaims) IsEmailVerified() bool { + emailVerified := false + + edov := c.XMicrosoftEmailDomainOwnerVerified + + // If xms_edov is not set, and an email is present or xms_edov is true, + // only then is the email regarded as verified. + // https://learn.microsoft.com/en-us/azure/active-directory/develop/migrate-off-email-claim-authorization#using-the-xms_edov-optional-claim-to-determine-email-verification-status-and-migrate-users + if edov == nil { + // An email is provided, but xms_edov is not -- probably not + // configured, so we must assume the email is verified as Azure + // will only send out a potentially unverified email address in + // single-tenanat apps. + emailVerified = c.Email != "" + } else { + edovBool := false + + // Azure can't be trusted with how they encode the xms_edov + // claim. Sometimes it's "xms_edov": "1", sometimes "xms_edov": true. + switch v := edov.(type) { + case bool: + edovBool = v + + case string: + edovBool = v == "1" || v == "true" + + default: + edovBool = false + } + + emailVerified = c.Email != "" && edovBool + } + + return emailVerified +} + +// removeAzureClaimsFromCustomClaims contains the list of claims to be removed +// from the CustomClaims map. See: +// https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference +var removeAzureClaimsFromCustomClaims = []string{ + "aud", + "iss", + "iat", + "nbf", + "exp", + "c_hash", + "at_hash", + "aio", + "nonce", + "rh", + "uti", + "jti", + "ver", + "sub", + "name", + "preferred_username", +} + +func parseAzureIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var data UserProvidedData + + var azureClaims AzureIDTokenClaims + if err := token.Claims(&azureClaims); err != nil { + return nil, nil, err + } + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + ProviderId: token.Subject, + PreferredUsername: azureClaims.PreferredUsername, + FullName: azureClaims.Name, + CustomClaims: make(map[string]any), + } + + if azureClaims.Email != "" { + data.Emails = []Email{{ + Email: azureClaims.Email, + Verified: azureClaims.IsEmailVerified(), + Primary: true, + }} + } + + if err := token.Claims(&data.Metadata.CustomClaims); err != nil { + return nil, nil, err + } + + if data.Metadata.CustomClaims != nil { + for _, claim := range removeAzureClaimsFromCustomClaims { + delete(data.Metadata.CustomClaims, claim) + } + } + + return token, &data, nil +} + +type KakaoIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + Nickname string `json:"nickname"` + Picture string `json:"picture"` +} + +func parseKakaoIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims KakaoIDTokenClaims + + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + + if claims.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: claims.Email, + Verified: true, + Primary: true, + }) + } + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + Name: claims.Nickname, + PreferredUsername: claims.Nickname, + ProviderId: token.Subject, + Picture: claims.Picture, + } + + return token, &data, nil +} + +type VercelMarketplaceIDTokenClaims struct { + jwt.RegisteredClaims + + UserEmail string `json:"user_email"` + UserName string `json:"user_name"` + UserAvatarUrl string `json:"user_avatar_url"` +} + +func parseVercelMarketplaceIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims VercelMarketplaceIDTokenClaims + + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + + data.Emails = append(data.Emails, Email{ + Email: claims.UserEmail, + Verified: true, + Primary: true, + }) + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + ProviderId: token.Subject, + Name: claims.UserName, + Picture: claims.UserAvatarUrl, + } + + return token, &data, nil +} + +func parseGenericIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var data UserProvidedData + + if err := token.Claims(&data.Metadata); err != nil { + return nil, nil, err + } + + if data.Metadata.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: data.Metadata.Email, + Verified: data.Metadata.EmailVerified, + Primary: true, + }) + } + + if len(data.Emails) <= 0 { + return nil, nil, fmt.Errorf("provider: Generic OIDC ID token from issuer %q must contain an email address", token.Issuer) + } + + return token, &data, nil +} diff --git a/internal/api/provider/oidc_test.go b/internal/api/provider/oidc_test.go new file mode 100644 index 000000000..e088cd45f --- /dev/null +++ b/internal/api/provider/oidc_test.go @@ -0,0 +1,185 @@ +package provider + +import ( + "context" + "crypto" + "crypto/rsa" + "encoding/base64" + "math/big" + "testing" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/stretchr/testify/require" +) + +type realIDToken struct { + AccessToken string + IDToken string + Time time.Time + Email string + Verifier func(context.Context, *oidc.Config) *oidc.IDTokenVerifier +} + +func googleIDTokenVerifier(ctx context.Context, config *oidc.Config) *oidc.IDTokenVerifier { + keyBytes, err := base64.RawURLEncoding.DecodeString("pP-rCe4jkKX6mq8yP1GcBZcxJzmxKWicHHor1S3Q49u6Oe-bQsk5NsK5mdR7Y7liGV9n0ikXSM42dYKQdxbhKA-7--fFon5isJoHr4fIwL2CCwVm5QWlK37q6PiH2_F1M0hRorHfkCb4nI56ZvfygvuOH4LIS82OzIgmsYbeEfwDRpeMSxWKwlpa3pX3GZ6jG7FgzJGBvmBkagpgsa2JZdyU4gEGMOkHdSzi5Ii-6RGfFLhhI1OMxC9P2JaU5yjMN2pikfFIq_dbpm75yNUGpWJNVywtrlNvvJfA74UMN_lVCAaSR0A03BUMg6ljB65gFllpKF224uWBA8tpjngwKQ") + if err != nil { + panic(err) + } + + n := big.NewInt(0) + n.SetBytes(keyBytes) + + publicKey := &rsa.PublicKey{ + N: n, + E: 65537, + } + + return oidc.NewVerifier( + "https://accounts.google.com", + &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{publicKey}, + }, + config, + ) +} + +func azureIDTokenVerifier(ctx context.Context, config *oidc.Config) *oidc.IDTokenVerifier { + keyBytes, err := base64.RawURLEncoding.DecodeString("1djHqyNclRpJWtHCnkP5QWvDxozCTG_ZDnkEmudpcxjnYrVL4RVIwdNCBLAStg8Dob5OUyAlHcRFMCqGTW4HA6kHgIxyfiFsYCBDMHWd2-61N1cAS6S9SdXlWXkBQgU0Qj6q_yFYTRS7J-zI_jMLRQAlpowfDFM1vSTBIci7kqynV6pPOz4jMaDQevmSscEs-jz7e8YXAiiVpN588oBQ0jzQaTTx90WjgRP23mn8mPyabj8gcR3gLwKLsBUhlp1oZj7FopGp8z8LHuueJB_q_LOUa_gAozZ0lfoJxFimXgpgEK7GNVdMRsMH3mIl0A5oYN8f29RFwbG0rNO5ZQ1YWQ") + if err != nil { + panic(err) + } + + n := big.NewInt(0) + n.SetBytes(keyBytes) + + publicKey := &rsa.PublicKey{ + N: n, + E: 65537, + } + + return oidc.NewVerifier( + IssuerAzureMicrosoft, + &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{publicKey}, + }, + config, + ) +} + +var realIDTokens map[string]realIDToken = map[string]realIDToken{ + IssuerGoogle: { + AccessToken: "ya29.a0AWY7CklOn4TehiT4kA6osNP6e-pHErOY8X53T2oUe7Oqqwc3-uIJpoEgoZCUogewBuNWr-JFT2FK9s0E0oRSFtAfu0-uIDckBj5ca1pxnk0-zPkPZouqoIyl0AlIpQjIUEuyuQTYUay99kRajbHcFCR1VMbNcQaCgYKAQESARESFQG1tDrp1joUHupV5Rn8-nWDpKkmMw0165", + IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6Ijg1YmE5MzEzZmQ3YTdkNGFmYTg0ODg0YWJjYzg0MDMwMDQzNjMxODAiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI5MTQ2NjY0MjA3NS03OWNwaWs4aWNxYzU4NjY5bjdtaXY5NjZsYmFwOTNhMi5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbSIsImF1ZCI6IjkxNDY2NjQyMDc1LTc5Y3BpazhpY3FjNTg2NjluN21pdjk2NmxiYXA5M2EyLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29tIiwic3ViIjoiMTAzNzgzMTkwMTI2NDM5NzUxMjY5IiwiaGQiOiJzdXBhYmFzZS5pbyIsImVtYWlsIjoic3RvamFuQHN1cGFiYXNlLmlvIiwiZW1haWxfdmVyaWZpZWQiOnRydWUsImF0X2hhc2giOiJlcGVWV244VmxWa28zd195Unk3UDZRIiwibmFtZSI6IlN0b2phbiBEaW1pdHJvdnNraSIsInBpY3R1cmUiOiJodHRwczovL2xoMy5nb29nbGV1c2VyY29udGVudC5jb20vYS9BQWNIVHRka0dhWjVlcGtqT1dxSEF1UUV4N2cwRlBCeXJiQ2ZNUjVNTk5kYz1zOTYtYyIsImdpdmVuX25hbWUiOiJTdG9qYW4iLCJmYW1pbHlfbmFtZSI6IkRpbWl0cm92c2tpIiwibG9jYWxlIjoiZW4tR0IiLCJpYXQiOjE2ODY2NTk5MzIsImV4cCI6MTY4NjY2MzUzMn0.nKAN9BFSxvavXYfWX4fZHREYY_3O4uOFRFq1KU1NNrBOMq_CPpM8c8PV7ZhKQvGCjBthSjtxGWbcqT0ByA7RdpNW6kj5UpFxEPdhenZ-eO1FwiEVIC8uZpiX6J3Nr7fAqi1P0DVeB3Zr_GrtkS9MDhZNb3hE5NDkvjCulwP4gRBC-5Pn_aRJRESxYkr_naKiSSmVilkmNVjZO4orq6KuYlvWHKHZIRiUI1akt0gVr5GxsEpd_duzUU30yVSPiq8l6fgxvJn2hT0MHa77wo3hvlP0NyAoSE7Nh4tRSowB0Qq7_byDMUmNWfXh-Qqa2M6ywuJ-_3LTLNUJH-cwdm2tNQ", + Time: time.Unix(1686659933, 0), // 1 sec after iat + Verifier: googleIDTokenVerifier, + }, + IssuerAzureMicrosoft: { + AccessToken: "access-token", + Time: time.Unix(1697277774, 0), // 1 sec after iat + IDToken: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6IlhvdVhMWVExVGlwNW9kWWFqaUN0RlZnVmFFcyJ9.eyJ2ZXIiOiIyLjAiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vOTE4ODA0MGQtNmM2Ny00YzViLWIxMTItMzZhMzA0YjY2ZGFkL3YyLjAiLCJzdWIiOiJBQUFBQUFBQUFBQUFBQUFBQUFBQUFCWkRuRDkxOTBfc2wxcTZwenZlRHZNIiwiYXVkIjoiYTBkOGY5NzItNTRhYy00YWJmLTkxNGMtNTIyMDE0YzQwMjJhIiwiZXhwIjoxNjk3MzY0NDczLCJpYXQiOjE2OTcyNzc3NzMsIm5iZiI6MTY5NzI3Nzc3MywiZW1haWwiOiJzZGltaXRyb3Zza2lAZ21haWwuY29tIiwidGlkIjoiOTE4ODA0MGQtNmM2Ny00YzViLWIxMTItMzZhMzA0YjY2ZGFkIiwieG1zX2Vkb3YiOiIxIiwiYWlvIjoiRHBQV3lZSnRJcUl5OHpyVjROIUlIdGtFa09BMDhPS29lZ1RkYmZQUEVPYmxtYk9ESFQ0cGJVcVI1cExraENyWWZ6bUgzb3A1RzN5RGp2M0tNZ0Rad29lQ1FjKmVueldyb21iQ3BuKkR6OEpQOGMxU3pEVG1TbGp4U3U3UnVLTXNZSjRvS1lDazFBSVcqUUNUTmlMWkpUKlN3WWZQcjZBTW9IejFEZ3pBZEFkbk9uWiFHNUNFeEtQalBxcHRuVmpUZlEkJCJ9.CskICxOaeqd4SkiPdWEHJKZVdhAdgzM5SN7K7FYi0dguQH1-v6XTetDIoEsBn0GZoozXjbG2GgkFcVhhBvNA0ZrDIr4KcjfnJ5-7rwX3AtxdQ3umrHRlGu3jlmbDOtWzPWNMLLRXfR1Mm3pHEUvlzqmk3Ffh4TuAmXID-fb-Xmfuuv1k0UsZ5mlr_3ybTPVZk-Lj0bqkR1L5Zzt4HjgfpchRryJ3Y24b4dDsSjg7mgE_5JivgjhtVef5OnqYhKUF1DTy2pFysFO_eRliK6qjouYeZnQOJnWHP1MgpySAOQ3sVcwvE4P9g7V3QouxByZPv-g99N1K4GwZrtdm46gtTQ", + Verifier: azureIDTokenVerifier, + }, +} + +func TestParseIDToken(t *testing.T) { + defer func() { + OverrideVerifiers = make(map[string]func(context.Context, *oidc.Config) *oidc.IDTokenVerifier) + OverrideClock = nil + }() + + // note that this test can fail if/when the issuers rotate their + // signing keys (which happens rarely if ever) + // then you should obtain new ID tokens and update this test + for issuer, token := range realIDTokens { + oidcProvider, err := oidc.NewProvider(context.Background(), issuer) + require.NoError(t, err) + + OverrideVerifiers[oidcProvider.Endpoint().AuthURL] = token.Verifier + + _, user, err := ParseIDToken(context.Background(), oidcProvider, &oidc.Config{ + SkipClientIDCheck: true, + Now: func() time.Time { + return token.Time + }, + }, token.IDToken, ParseIDTokenOptions{ + AccessToken: token.AccessToken, + }) + require.NoError(t, err) + + require.NotEmpty(t, user.Emails[0].Email) + require.Equal(t, user.Emails[0].Verified, true) + } +} + +func TestAzureIDTokenClaimsIsEmailVerified(t *testing.T) { + positiveExamples := []AzureIDTokenClaims{ + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: nil, + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: true, + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: "1", + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: "true", + }, + } + + negativeExamples := []AzureIDTokenClaims{ + { + Email: "", + XMicrosoftEmailDomainOwnerVerified: true, + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: false, + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: "0", + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: "false", + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: float32(0), + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: float64(0), + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: int(0), + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: int32(0), + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: int64(0), + }, + } + + for i, example := range positiveExamples { + if !example.IsEmailVerified() { + t.Errorf("positive example %v reports negative result", i) + } + } + + for i, example := range negativeExamples { + if example.IsEmailVerified() { + t.Errorf("negative example %v reports positive result", i) + } + } +} diff --git a/internal/api/provider/provider.go b/internal/api/provider/provider.go new file mode 100644 index 000000000..857e88252 --- /dev/null +++ b/internal/api/provider/provider.go @@ -0,0 +1,128 @@ +package provider + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log" + "net/http" + "os" + "time" + + "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" +) + +var defaultTimeout time.Duration = time.Second * 10 + +func init() { + timeoutStr := os.Getenv("GOTRUE_INTERNAL_HTTP_TIMEOUT") + if timeoutStr != "" { + if timeout, err := time.ParseDuration(timeoutStr); err != nil { + log.Fatalf("error loading GOTRUE_INTERNAL_HTTP_TIMEOUT: %v", err.Error()) + } else if timeout != 0 { + defaultTimeout = timeout + } + } +} + +type Claims struct { + // Reserved claims + Issuer string `json:"iss,omitempty" structs:"iss,omitempty"` + Subject string `json:"sub,omitempty" structs:"sub,omitempty"` + Aud string `json:"aud,omitempty" structs:"aud,omitempty"` + Iat float64 `json:"iat,omitempty" structs:"iat,omitempty"` + Exp float64 `json:"exp,omitempty" structs:"exp,omitempty"` + + // Default profile claims + Name string `json:"name,omitempty" structs:"name,omitempty"` + FamilyName string `json:"family_name,omitempty" structs:"family_name,omitempty"` + GivenName string `json:"given_name,omitempty" structs:"given_name,omitempty"` + MiddleName string `json:"middle_name,omitempty" structs:"middle_name,omitempty"` + NickName string `json:"nickname,omitempty" structs:"nickname,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty" structs:"preferred_username,omitempty"` + Profile string `json:"profile,omitempty" structs:"profile,omitempty"` + Picture string `json:"picture,omitempty" structs:"picture,omitempty"` + Website string `json:"website,omitempty" structs:"website,omitempty"` + Gender string `json:"gender,omitempty" structs:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty" structs:"birthdate,omitempty"` + ZoneInfo string `json:"zoneinfo,omitempty" structs:"zoneinfo,omitempty"` + Locale string `json:"locale,omitempty" structs:"locale,omitempty"` + UpdatedAt string `json:"updated_at,omitempty" structs:"updated_at,omitempty"` + Email string `json:"email,omitempty" structs:"email,omitempty"` + EmailVerified bool `json:"email_verified,omitempty" structs:"email_verified"` + Phone string `json:"phone,omitempty" structs:"phone,omitempty"` + PhoneVerified bool `json:"phone_verified,omitempty" structs:"phone_verified"` + + // Custom profile claims that are provider specific + CustomClaims map[string]interface{} `json:"custom_claims,omitempty" structs:"custom_claims,omitempty"` + + // TODO: Deprecate in next major release + FullName string `json:"full_name,omitempty" structs:"full_name,omitempty"` + AvatarURL string `json:"avatar_url,omitempty" structs:"avatar_url,omitempty"` + Slug string `json:"slug,omitempty" structs:"slug,omitempty"` + ProviderId string `json:"provider_id,omitempty" structs:"provider_id,omitempty"` + UserNameKey string `json:"user_name,omitempty" structs:"user_name,omitempty"` +} + +// Email is a struct that provides information on whether an email is verified or is the primary email address +type Email struct { + Email string + Verified bool + Primary bool +} + +// UserProvidedData is a struct that contains the user's data returned from the oauth provider +type UserProvidedData struct { + Emails []Email + Metadata *Claims +} + +// Provider is an interface for interacting with external account providers +type Provider interface { + AuthCodeURL(string, ...oauth2.AuthCodeOption) string +} + +// OAuthProvider specifies additional methods needed for providers using OAuth +type OAuthProvider interface { + AuthCodeURL(string, ...oauth2.AuthCodeOption) string + GetUserData(context.Context, *oauth2.Token) (*UserProvidedData, error) + GetOAuthToken(string) (*oauth2.Token, error) +} + +func chooseHost(base, defaultHost string) string { + if base == "" { + return "https://" + defaultHost + } + + baseLen := len(base) + if base[baseLen-1] == '/' { + return base[:baseLen-1] + } + + return base +} + +func makeRequest(ctx context.Context, tok *oauth2.Token, g *oauth2.Config, url string, dst interface{}) error { + client := g.Client(ctx, tok) + client.Timeout = defaultTimeout + res, err := client.Get(url) + if err != nil { + return err + } + defer utilities.SafeClose(res.Body) + + bodyBytes, _ := io.ReadAll(res.Body) + res.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices { + return httpError(res.StatusCode, string(bodyBytes)) + } + + if err := json.NewDecoder(res.Body).Decode(dst); err != nil { + return err + } + + return nil +} diff --git a/api/provider/slack.go b/internal/api/provider/slack.go similarity index 64% rename from api/provider/slack.go rename to internal/api/provider/slack.go index 896bc8d09..40377b0aa 100644 --- a/api/provider/slack.go +++ b/internal/api/provider/slack.go @@ -2,11 +2,9 @@ package provider import ( "context" - "errors" - "fmt" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) @@ -22,11 +20,12 @@ type slackUser struct { Email string `json:"email"` Name string `json:"name"` AvatarURL string `json:"picture"` + TeamID string `json:"https://slack.com/team_id"` } -// NewSlackProvider creates a Slack account provider. +// NewSlackProvider creates a Slack account provider with Legacy Slack OAuth. func NewSlackProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -45,7 +44,7 @@ func NewSlackProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuth return &slackProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: authPath + "/authorize", @@ -59,7 +58,7 @@ func NewSlackProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuth } func (g slackProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) + return g.Exchange(context.Background(), code) } func (g slackProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { @@ -67,29 +66,29 @@ func (g slackProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Use if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/openid.connect.userInfo", &u); err != nil { return nil, err } - fmt.Printf("%+v\n", u) - if u.Email == "" { - return nil, errors.New("Unable to find email with Slack provider") - } - - return &UserProvidedData{ - Metadata: &Claims{ - Issuer: g.APIPath, - Subject: u.ID, - Name: u.Name, - Picture: u.AvatarURL, - Email: u.Email, - EmailVerified: true, // Slack dosen't provide data on if email is verified. - // To be deprecated - AvatarURL: u.AvatarURL, - FullName: u.Name, - ProviderId: u.ID, - }, - Emails: []Email{{ + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ Email: u.Email, - Verified: true, // Slack dosen't provide data on if email is verified. + Verified: true, // Slack doesn't provide data on if email is verified. Primary: true, - }}, - }, nil + }} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: u.Name, + Picture: u.AvatarURL, + CustomClaims: map[string]interface{}{ + "https://slack.com/team_id": u.TeamID, + }, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + return data, nil } diff --git a/internal/api/provider/slack_oidc.go b/internal/api/provider/slack_oidc.go new file mode 100644 index 000000000..3c7a5eb62 --- /dev/null +++ b/internal/api/provider/slack_oidc.go @@ -0,0 +1,99 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const defaultSlackOIDCApiBase = "slack.com" + +type slackOIDCProvider struct { + *oauth2.Config + APIPath string +} + +type slackOIDCUser struct { + ID string `json:"https://slack.com/user_id"` + TeamID string `json:"https://slack.com/team_id"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + AvatarURL string `json:"picture"` +} + +// NewSlackOIDCProvider creates a Slack account provider with Sign in with Slack. +func NewSlackOIDCProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultSlackOIDCApiBase) + "/api" + authPath := chooseHost(ext.URL, defaultSlackOIDCApiBase) + "/openid" + + // these are required scopes for slack's OIDC flow + // see https://api.slack.com/authentication/sign-in-with-slack#implementation + oauthScopes := []string{ + "profile", + "email", + "openid", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + return &slackOIDCProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authPath + "/connect/authorize", + TokenURL: apiPath + "/openid.connect.token", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g slackOIDCProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g slackOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u slackOIDCUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/openid.connect.userInfo", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + // email_verified is returned as part of the response + // see: https://api.slack.com/authentication/sign-in-with-slack#response + Verified: u.EmailVerified, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: u.Name, + Picture: u.AvatarURL, + CustomClaims: map[string]interface{}{ + "https://slack.com/team_id": u.TeamID, + }, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + return data, nil +} diff --git a/api/provider/spotify.go b/internal/api/provider/spotify.go similarity index 64% rename from api/provider/spotify.go rename to internal/api/provider/spotify.go index 0334b21c6..e6d2f383c 100644 --- a/api/provider/spotify.go +++ b/internal/api/provider/spotify.go @@ -2,10 +2,9 @@ package provider import ( "context" - "errors" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) @@ -34,7 +33,7 @@ type spotifyUserImage struct { // NewSpotifyProvider creates a Spotify account provider. func NewSpotifyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -51,7 +50,7 @@ func NewSpotifyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAu return &spotifyProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: authPath + "/authorize", @@ -65,7 +64,7 @@ func NewSpotifyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAu } func (g spotifyProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(oauth2.NoContext, code) + return g.Exchange(context.Background(), code) } func (g spotifyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { @@ -74,34 +73,42 @@ func (g spotifyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*U return nil, err } - if u.Email == "" { - return nil, errors.New("Unable to find email with Spotify provider") + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + // Spotify dosen't provide data on whether the user's email is verified. + // https://developer.spotify.com/documentation/web-api/reference/get-current-users-profile + Verified: false, + Primary: true, + }} } var avatarURL string + // Spotify returns a list of avatars, we want to use the largest one if len(u.Avatars) >= 1 { - avatarURL = u.Avatars[0].Url + largestAvatar := u.Avatars[0] + + for _, avatar := range u.Avatars { + if avatar.Height*avatar.Width > largestAvatar.Height*largestAvatar.Width { + largestAvatar = avatar + } + } + + avatarURL = largestAvatar.Url } - return &UserProvidedData{ - Metadata: &Claims{ - Issuer: g.APIPath, - Subject: u.ID, - Name: u.DisplayName, - Picture: avatarURL, - Email: u.Email, - EmailVerified: true, // Spotify dosen't provide data on if email is verified. - - // To be deprecated - AvatarURL: avatarURL, - FullName: u.DisplayName, - ProviderId: u.ID, - }, - Emails: []Email{{ - Email: u.Email, - Verified: true, // Spotify dosen't provide data on if email is verified. - Primary: true, - }}, - }, nil + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: u.DisplayName, + Picture: avatarURL, + + // To be deprecated + AvatarURL: avatarURL, + FullName: u.DisplayName, + ProviderId: u.ID, + } + return data, nil } diff --git a/api/provider/twitch.go b/internal/api/provider/twitch.go similarity index 66% rename from api/provider/twitch.go rename to internal/api/provider/twitch.go index 01bb806d6..defb1983a 100644 --- a/api/provider/twitch.go +++ b/internal/api/provider/twitch.go @@ -4,12 +4,14 @@ import ( "context" "encoding/json" "errors" - "io/ioutil" + "fmt" + "io" "net/http" "strings" "time" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" "golang.org/x/oauth2" ) @@ -43,7 +45,7 @@ type twitchUsers struct { // NewTwitchProvider creates a Twitch account provider. func NewTwitchProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -60,7 +62,7 @@ func NewTwitchProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut return &twitchProvider{ Config: &oauth2.Config{ - ClientID: ext.ClientID, + ClientID: ext.ClientID[0], ClientSecret: ext.Secret, Endpoint: oauth2.Endpoint{ AuthURL: authHost + "/oauth2/authorize", @@ -91,16 +93,26 @@ func (t twitchProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Us req.Header.Set("Client-Id", t.Config.ClientID) req.Header.Set("Authorization", "Bearer "+tok.AccessToken) - client := &http.Client{} + client := &http.Client{Timeout: defaultTimeout} resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer utilities.SafeClose(resp.Body) + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("a %v error occurred with retrieving user from twitch", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } - body, _ := ioutil.ReadAll(resp.Body) - json.Unmarshal(body, &u) - defer resp.Body.Close() + err = json.Unmarshal(body, &u) + if err != nil { + return nil, err + } if len(u.Data) == 0 { return nil, errors.New("unable to find user with twitch provider") @@ -108,38 +120,34 @@ func (t twitchProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Us user := u.Data[0] - if user.Email == "" { - return nil, errors.New("unable to find email with twitch provider") - } - - data := &UserProvidedData{ - Metadata: &Claims{ - Issuer: t.APIHost, - Subject: user.ID, - Picture: user.ProfileImageURL, - Name: user.Login, - NickName: user.DisplayName, - Email: user.Email, - EmailVerified: true, - CustomClaims: map[string]interface{}{ - "broadcaster_type": user.BroadcasterType, - "description": user.Description, - "type": user.Type, - "offline_image_url": user.OfflineImageURL, - "view_count": user.ViewCount, - }, - - // To be deprecated - Slug: user.DisplayName, - AvatarURL: user.ProfileImageURL, - FullName: user.Login, - ProviderId: user.ID, - }, - Emails: []Email{{ + data := &UserProvidedData{} + if user.Email != "" { + data.Emails = []Email{{ Email: user.Email, Verified: true, Primary: true, - }}, + }} + } + + data.Metadata = &Claims{ + Issuer: t.APIHost, + Subject: user.ID, + Picture: user.ProfileImageURL, + Name: user.Login, + NickName: user.DisplayName, + CustomClaims: map[string]interface{}{ + "broadcaster_type": user.BroadcasterType, + "description": user.Description, + "type": user.Type, + "offline_image_url": user.OfflineImageURL, + "view_count": user.ViewCount, + }, + + // To be deprecated + Slug: user.DisplayName, + AvatarURL: user.ProfileImageURL, + FullName: user.Login, + ProviderId: user.ID, } return data, nil diff --git a/api/provider/twitter.go b/internal/api/provider/twitter.go similarity index 77% rename from api/provider/twitter.go rename to internal/api/provider/twitter.go index a3a752c32..8dc5a4c64 100644 --- a/api/provider/twitter.go +++ b/internal/api/provider/twitter.go @@ -4,14 +4,14 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" - "io/ioutil" + "io" "net/http" "strings" "github.com/mrjones/oauth" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" "golang.org/x/oauth2" ) @@ -19,7 +19,7 @@ const ( defaultTwitterAPIBase = "api.twitter.com" requestURL = "/oauth/request_token" authenticateURL = "/oauth/authenticate" - tokenURL = "/oauth/access_token" + tokenURL = "/oauth/access_token" //#nosec G101 -- Not a secret value. endpointProfile = "/1.1/account/verify_credentials.json" ) @@ -45,12 +45,12 @@ type twitterUser struct { // NewTwitterProvider creates a Twitter account provider. func NewTwitterProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { - if err := ext.Validate(); err != nil { + if err := ext.ValidateOAuth(); err != nil { return nil, err } authHost := chooseHost(ext.URL, defaultTwitterAPIBase) p := &TwitterProvider{ - ClientKey: ext.ClientID, + ClientKey: ext.ClientID[0], Secret: ext.Secret, CallbackURL: ext.RedirectURI, UserInfoURL: authHost + endpointProfile, @@ -80,41 +80,37 @@ func (t TwitterProvider) FetchUserData(ctx context.Context, tok *oauth.AccessTok if err != nil { return nil, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return &UserProvidedData{}, fmt.Errorf("Twitter responded with a %d trying to fetch user information", resp.StatusCode) + defer utilities.SafeClose(resp.Body) + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return &UserProvidedData{}, fmt.Errorf("a %v error occurred with retrieving user from twitter", resp.StatusCode) } - bits, err := ioutil.ReadAll(resp.Body) + bits, err := io.ReadAll(resp.Body) if err != nil { return nil, err } - err = json.NewDecoder(bytes.NewReader(bits)).Decode(&u) + _ = json.NewDecoder(bytes.NewReader(bits)).Decode(&u) - if u.Email == "" { - return nil, errors.New("Unable to find email with Twitter provider") - } - - data := &UserProvidedData{ - Metadata: &Claims{ - Issuer: t.UserInfoURL, - Subject: u.ID, - Name: u.Name, - Picture: u.AvatarURL, - PreferredUsername: u.UserName, - Email: u.Email, - EmailVerified: true, - - // To be deprecated - UserNameKey: u.UserName, - FullName: u.Name, - AvatarURL: u.AvatarURL, - ProviderId: u.ID, - }, - Emails: []Email{{ + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ Email: u.Email, Verified: true, Primary: true, - }}, + }} + } + + data.Metadata = &Claims{ + Issuer: t.UserInfoURL, + Subject: u.ID, + Name: u.Name, + Picture: u.AvatarURL, + PreferredUsername: u.UserName, + + // To be deprecated + UserNameKey: u.UserName, + FullName: u.Name, + AvatarURL: u.AvatarURL, + ProviderId: u.ID, } return data, nil diff --git a/internal/api/provider/vercel_marketplace.go b/internal/api/provider/vercel_marketplace.go new file mode 100644 index 000000000..ba76a7412 --- /dev/null +++ b/internal/api/provider/vercel_marketplace.go @@ -0,0 +1,78 @@ +package provider + +import ( + "context" + "errors" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultVercelMarketplaceAPIBase = "api.vercel.com" + IssuerVercelMarketplace = "https://marketplace.vercel.com" +) + +type vercelMarketplaceProvider struct { + *oauth2.Config + oidc *oidc.Provider + APIPath string +} + +// NewVercelMarketplaceProvider creates a VercelMarketplace account provider via OIDC. +func NewVercelMarketplaceProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultVercelMarketplaceAPIBase) + + oauthScopes := []string{} + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + oidcProvider, err := oidc.NewProvider(context.Background(), IssuerVercelMarketplace) + if err != nil { + return nil, err + } + + return &vercelMarketplaceProvider{ + oidc: oidcProvider, + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: apiPath + "/oauth/v2/authorization", + TokenURL: apiPath + "/oauth/v2/accessToken", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g vercelMarketplaceProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g vercelMarketplaceProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + idToken := tok.Extra("id_token") + if tok.AccessToken == "" || idToken == nil { + return nil, errors.New("vercel_marketplace: no OIDC ID token present in response") + } + + _, data, err := ParseIDToken(ctx, g.oidc, &oidc.Config{ + ClientID: g.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + return data, nil +} diff --git a/internal/api/provider/workos.go b/internal/api/provider/workos.go new file mode 100644 index 000000000..75cafa27d --- /dev/null +++ b/internal/api/provider/workos.go @@ -0,0 +1,98 @@ +package provider + +import ( + "context" + "strings" + + "github.com/mitchellh/mapstructure" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultWorkOSAPIBase = "api.workos.com" +) + +type workosProvider struct { + *oauth2.Config + APIPath string +} + +// See https://workos.com/docs/reference/sso/profile. +type workosUser struct { + ID string `mapstructure:"id"` + ConnectionID string `mapstructure:"connection_id"` + OrganizationID string `mapstructure:"organization_id"` + ConnectionType string `mapstructure:"connection_type"` + Email string `mapstructure:"email"` + FirstName string `mapstructure:"first_name"` + LastName string `mapstructure:"last_name"` + Object string `mapstructure:"object"` + IdpID string `mapstructure:"idp_id"` + RawAttributes map[string]interface{} `mapstructure:"raw_attributes"` +} + +// NewWorkOSProvider creates a WorkOS account provider. +func NewWorkOSProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + apiPath := chooseHost(ext.URL, defaultWorkOSAPIBase) + + return &workosProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: apiPath + "/sso/authorize", + TokenURL: apiPath + "/sso/token", + }, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g workosProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g workosProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + if tok.AccessToken == "" { + return &UserProvidedData{}, nil + } + + // WorkOS API returns the user's profile data along with the OAuth2 token, so + // we can just convert from `map[string]interface{}` to `workosUser` without + // an additional network request. + var u workosUser + err := mapstructure.Decode(tok.Extra("profile"), &u) + if err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: strings.TrimSpace(u.FirstName + " " + u.LastName), + CustomClaims: map[string]interface{}{ + "connection_id": u.ConnectionID, + "organization_id": u.OrganizationID, + }, + + // To be deprecated + FullName: strings.TrimSpace(u.FirstName + " " + u.LastName), + ProviderId: u.ID, + } + + return data, nil +} diff --git a/internal/api/provider/zoom.go b/internal/api/provider/zoom.go new file mode 100644 index 000000000..8e2e9fa4d --- /dev/null +++ b/internal/api/provider/zoom.go @@ -0,0 +1,91 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultZoomAuthBase = "zoom.us" + defaultZoomAPIBase = "api.zoom.us" +) + +type zoomProvider struct { + *oauth2.Config + APIPath string +} + +type zoomUser struct { + ID string `json:"id"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Email string `json:"email"` + EmailVerified int `json:"verified"` + LoginType string `json:"login_type"` + AvatarURL string `json:"pic_url"` +} + +// NewZoomProvider creates a Zoom account provider. +func NewZoomProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultZoomAPIBase) + "/v2" + authPath := chooseHost(ext.URL, defaultZoomAuthBase) + "/oauth" + + return &zoomProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authPath + "/authorize", + TokenURL: authPath + "/token", + }, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g zoomProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g zoomProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u zoomUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/users/me", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + email := Email{} + email.Email = u.Email + email.Primary = true + // A login_type of "100" refers to email-based logins, not oauth. + // A user is verified (type 1) only if they received an email when their profile was created and confirmed the link. + // A zoom user will only be sent an email confirmation link if they signed up using their zoom work email and not oauth. + // See: https://devforum.zoom.us/t/how-to-determine-if-a-zoom-user-actually-owns-their-email-address/44430 + if u.LoginType != "100" || u.EmailVerified != 0 { + email.Verified = true + } + data.Emails = []Email{email} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: strings.TrimSpace(u.FirstName + " " + u.LastName), + Picture: u.AvatarURL, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: strings.TrimSpace(u.FirstName + " " + u.LastName), + ProviderId: u.ID, + } + return data, nil +} diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go new file mode 100644 index 000000000..57eb90505 --- /dev/null +++ b/internal/api/reauthenticate.go @@ -0,0 +1,98 @@ +package api + +import ( + "net/http" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +const InvalidNonceMessage = "Nonce has expired or is invalid" + +// Reauthenticate sends a reauthentication otp to either the user's email or phone +func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + user := getUser(ctx) + email, phone := user.GetEmail(), user.GetPhone() + + if email == "" && phone == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Reauthentication requires the user to have an email or a phone number") + } + + if email != "" { + if !user.IsConfirmed() { + return unprocessableEntityError(apierrors.ErrorCodeEmailNotConfirmed, "Please verify your email first.") + } + } else if phone != "" { + if !user.IsPhoneConfirmed() { + return unprocessableEntityError(apierrors.ErrorCodePhoneNotConfirmed, "Please verify your phone first.") + } + } + + messageID := "" + err := db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserReauthenticateAction, "", nil); terr != nil { + return terr + } + if email != "" { + return a.sendReauthenticationOtp(r, tx, user) + } else if phone != "" { + mID, err := a.sendPhoneConfirmation(r, tx, user, phone, phoneReauthenticationOtp, sms_provider.SMSProvider) + if err != nil { + return err + } + + messageID = mID + } + return nil + }) + if err != nil { + return err + } + + ret := map[string]any{} + if messageID != "" { + ret["message_id"] = messageID + + } + + return sendJSON(w, http.StatusOK, ret) +} + +// verifyReauthentication checks if the nonce provided is valid +func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, config *conf.GlobalConfiguration, user *models.User) error { + if user.ReauthenticationToken == "" || user.ReauthenticationSentAt == nil { + return unprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, InvalidNonceMessage) + } + var isValid bool + if user.GetEmail() != "" { + tokenHash := crypto.GenerateTokenHash(user.GetEmail(), nonce) + isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Mailer.OtpExp) + } else if user.GetPhone() != "" { + if config.Sms.IsTwilioVerifyProvider() { + smsProvider, _ := sms_provider.GetSmsProvider(*config) + if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(string(user.Phone), nonce); err != nil { + return forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + } + return nil + } else { + tokenHash := crypto.GenerateTokenHash(user.GetPhone(), nonce) + isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Sms.OtpExp) + } + } else { + return unprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, "Reauthentication requires an email or a phone number") + } + if !isValid { + return unprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, InvalidNonceMessage) + } + if err := user.ConfirmReauthentication(tx); err != nil { + return internalServerError("Error during reauthentication").WithInternalError(err) + } + return nil +} diff --git a/internal/api/recover.go b/internal/api/recover.go new file mode 100644 index 000000000..fed11d355 --- /dev/null +++ b/internal/api/recover.go @@ -0,0 +1,74 @@ +package api + +import ( + "net/http" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// RecoverParams holds the parameters for a password recovery request +type RecoverParams struct { + Email string `json:"email"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` +} + +func (p *RecoverParams) Validate(a *API) error { + if p.Email == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Password recovery requires an email") + } + var err error + if p.Email, err = a.validateEmail(p.Email); err != nil { + return err + } + if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { + return err + } + return nil +} + +// Recover sends a recovery email +func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + params := &RecoverParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + flowType := getFlowFromChallenge(params.CodeChallenge) + if err := params.Validate(a); err != nil { + return err + } + + var user *models.User + var err error + aud := a.requestAud(ctx, r) + + user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) + if err != nil { + if models.IsNotFoundError(err) { + return sendJSON(w, http.StatusOK, map[string]string{}) + } + return internalServerError("Unable to process request").WithInternalError(err) + } + if isPKCEFlow(flowType) { + if _, err := generateFlowState(db, models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID)); err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { + return terr + } + return a.sendPasswordRecovery(r, tx, user, flowType) + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]string{}) +} diff --git a/api/recover_test.go b/internal/api/recover_test.go similarity index 66% rename from api/recover_test.go rename to internal/api/recover_test.go index 548d77278..a7e655c59 100644 --- a/api/recover_test.go +++ b/internal/api/recover_test.go @@ -8,30 +8,26 @@ import ( "testing" "time" - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" ) type RecoverTestSuite struct { suite.Suite API *API - Config *conf.Configuration - - instanceID uuid.UUID + Config *conf.GlobalConfiguration } func TestRecover(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() + api, config, err := setupAPIForTest() require.NoError(t, err) ts := &RecoverTestSuite{ - API: api, - Config: config, - instanceID: instanceID, + API: api, + Config: config, } defer api.db.Close() @@ -42,13 +38,13 @@ func (ts *RecoverTestSuite) SetupTest() { models.TruncateAll(ts.API.db) // Create user - u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil) + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) require.NoError(ts.T(), err, "Error creating test user model") require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") } func (ts *RecoverTestSuite) TestRecover_FirstRecovery() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) u.RecoverySentAt = &time.Time{} require.NoError(ts.T(), ts.API.db.Update(u)) @@ -68,7 +64,7 @@ func (ts *RecoverTestSuite) TestRecover_FirstRecovery() { ts.API.handler.ServeHTTP(w, req) assert.Equal(ts.T(), http.StatusOK, w.Code) - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) @@ -76,7 +72,7 @@ func (ts *RecoverTestSuite) TestRecover_FirstRecovery() { func (ts *RecoverTestSuite) TestRecover_NoEmailSent() { recoveryTime := time.Now().UTC().Add(-59 * time.Second) - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) u.RecoverySentAt = &recoveryTime require.NoError(ts.T(), ts.API.db.Update(u)) @@ -96,7 +92,7 @@ func (ts *RecoverTestSuite) TestRecover_NoEmailSent() { ts.API.handler.ServeHTTP(w, req) assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) // ensure it did not send a new email @@ -107,7 +103,7 @@ func (ts *RecoverTestSuite) TestRecover_NoEmailSent() { func (ts *RecoverTestSuite) TestRecover_NewEmailSent() { recoveryTime := time.Now().UTC().Add(-20 * time.Minute) - u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) u.RecoverySentAt = &recoveryTime require.NoError(ts.T(), ts.API.db.Update(u)) @@ -127,9 +123,31 @@ func (ts *RecoverTestSuite) TestRecover_NewEmailSent() { ts.API.handler.ServeHTTP(w, req) assert.Equal(ts.T(), http.StatusOK, w.Code) - u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) // ensure it sent a new email assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) } + +func (ts *RecoverTestSuite) TestRecover_NoSideChannelLeak() { + email := "doesntexist@example.com" + + _, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) + require.True(ts.T(), models.IsNotFoundError(err), "User with email %s does exist", email) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": email, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} diff --git a/internal/api/resend.go b/internal/api/resend.go new file mode 100644 index 000000000..bc3ed45cb --- /dev/null +++ b/internal/api/resend.go @@ -0,0 +1,155 @@ +package api + +import ( + "net/http" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/sms_provider" + mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// ResendConfirmationParams holds the parameters for a resend request +type ResendConfirmationParams struct { + Type string `json:"type"` + Email string `json:"email"` + Phone string `json:"phone"` +} + +func (p *ResendConfirmationParams) Validate(a *API) error { + config := a.config + + switch p.Type { + case mail.SignupVerification, mail.EmailChangeVerification, smsVerification, phoneChangeVerification: + break + default: + // type does not match one of the above + return badRequestError(apierrors.ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change") + + } + if p.Email == "" && p.Type == mail.SignupVerification { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Type provided requires an email address") + } + if p.Phone == "" && p.Type == smsVerification { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Type provided requires a phone number") + } + + var err error + if p.Email != "" && p.Phone != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided.") + } else if p.Email != "" { + if !config.External.Email.Enabled { + return badRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") + } + p.Email, err = a.validateEmail(p.Email) + if err != nil { + return err + } + } else if p.Phone != "" { + if !config.External.Phone.Enabled { + return badRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Phone logins are disabled") + } + p.Phone, err = validatePhone(p.Phone) + if err != nil { + return err + } + } else { + // both email and phone are empty + return badRequestError(apierrors.ErrorCodeValidationFailed, "Missing email address or phone number") + } + return nil +} + +// Recover sends a recovery email +func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + params := &ResendConfirmationParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.Validate(a); err != nil { + return err + } + + var user *models.User + var err error + aud := a.requestAud(ctx, r) + if params.Email != "" { + user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) + } else if params.Phone != "" { + user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) + } + + if err != nil { + if models.IsNotFoundError(err) { + return sendJSON(w, http.StatusOK, map[string]string{}) + } + return internalServerError("Unable to process request").WithInternalError(err) + } + + switch params.Type { + case mail.SignupVerification: + if user.IsConfirmed() { + // if the user's email is confirmed already, we don't need to send a confirmation email again + return sendJSON(w, http.StatusOK, map[string]string{}) + } + case smsVerification: + if user.IsPhoneConfirmed() { + // if the user's phone is confirmed already, we don't need to send a confirmation sms again + return sendJSON(w, http.StatusOK, map[string]string{}) + } + case mail.EmailChangeVerification: + // do not resend if user doesn't have a new email address + if user.EmailChange == "" { + return sendJSON(w, http.StatusOK, map[string]string{}) + } + case phoneChangeVerification: + // do not resend if user doesn't have a new phone number + if user.PhoneChange == "" { + return sendJSON(w, http.StatusOK, map[string]string{}) + } + } + + messageID := "" + err = db.Transaction(func(tx *storage.Connection) error { + switch params.Type { + case mail.SignupVerification: + if terr := models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", nil); terr != nil { + return terr + } + // PKCE not implemented yet + return a.sendConfirmation(r, tx, user, models.ImplicitFlow) + case smsVerification: + if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { + return terr + } + mID, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, sms_provider.SMSProvider) + if terr != nil { + return terr + } + messageID = mID + case mail.EmailChangeVerification: + return a.sendEmailChange(r, tx, user, user.EmailChange, models.ImplicitFlow) + case phoneChangeVerification: + mID, terr := a.sendPhoneConfirmation(r, tx, user, user.PhoneChange, phoneChangeVerification, sms_provider.SMSProvider) + if terr != nil { + return terr + } + messageID = mID + } + return nil + }) + if err != nil { + return err + } + + ret := map[string]any{} + if messageID != "" { + ret["message_id"] = messageID + } + + return sendJSON(w, http.StatusOK, ret) +} diff --git a/internal/api/resend_test.go b/internal/api/resend_test.go new file mode 100644 index 000000000..83c58c4e4 --- /dev/null +++ b/internal/api/resend_test.go @@ -0,0 +1,217 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" +) + +type ResendTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestResend(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &ResendTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *ResendTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) +} + +func (ts *ResendTestSuite) TestResendValidation() { + cases := []struct { + desc string + params map[string]interface{} + expected map[string]interface{} + }{ + { + desc: "Invalid type", + params: map[string]interface{}{ + "type": "invalid", + "email": "foo@example.com", + }, + expected: map[string]interface{}{ + "code": http.StatusBadRequest, + "message": "Missing one of these types: signup, email_change, sms, phone_change", + }, + }, + { + desc: "Type & email mismatch", + params: map[string]interface{}{ + "type": "sms", + "email": "foo@example.com", + }, + expected: map[string]interface{}{ + "code": http.StatusBadRequest, + "message": "Type provided requires a phone number", + }, + }, + { + desc: "Phone & email change type", + params: map[string]interface{}{ + "type": "email_change", + "phone": "+123456789", + }, + expected: map[string]interface{}{ + "code": http.StatusOK, + "message": nil, + }, + }, + { + desc: "Email & phone number provided", + params: map[string]interface{}{ + "type": "email_change", + "phone": "+123456789", + "email": "foo@example.com", + }, + expected: map[string]interface{}{ + "code": http.StatusBadRequest, + "message": "Only an email address or phone number should be provided.", + }, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + req := httptest.NewRequest(http.MethodPost, "http://localhost/resend", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected["code"], w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), c.expected["message"], data["msg"]) + }) + } + +} + +func (ts *ResendTestSuite) TestResendSuccess() { + // Create user + u, err := models.NewUser("123456789", "foo@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + + // Avoid max freq limit error + now := time.Now().Add(-1 * time.Minute) + + // Enable Phone Logoin for phone related tests + ts.Config.External.Phone.Enabled = true + // disable secure email change + ts.Config.Mailer.SecureEmailChangeEnabled = false + + u.ConfirmationToken = "123456" + u.ConfirmationSentAt = &now + u.EmailChange = "bar@example.com" + u.EmailChangeSentAt = &now + u.EmailChangeTokenNew = "123456" + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + + phoneUser, err := models.NewUser("1234567890", "", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + phoneUser.EmailChange = "bar@example.com" + phoneUser.EmailChangeSentAt = &now + phoneUser.EmailChangeTokenNew = "123456" + require.NoError(ts.T(), ts.API.db.Create(phoneUser), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, phoneUser.ID, phoneUser.EmailChange, phoneUser.EmailChangeTokenNew, models.EmailChangeTokenNew)) + + emailUser, err := models.NewUser("", "bar@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + phoneUser.PhoneChange = "1234567890" + phoneUser.PhoneChangeSentAt = &now + phoneUser.PhoneChangeToken = "123456" + require.NoError(ts.T(), ts.API.db.Create(emailUser), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, phoneUser.ID, phoneUser.PhoneChange, phoneUser.PhoneChangeToken, models.PhoneChangeToken)) + + cases := []struct { + desc string + params map[string]interface{} + // expected map[string]interface{} + user *models.User + }{ + { + desc: "Resend signup confirmation", + params: map[string]interface{}{ + "type": "signup", + "email": u.GetEmail(), + }, + user: u, + }, + { + desc: "Resend email change", + params: map[string]interface{}{ + "type": "email_change", + "email": u.GetEmail(), + }, + user: u, + }, + { + desc: "Resend email change for phone user", + params: map[string]interface{}{ + "type": "email_change", + "phone": phoneUser.GetPhone(), + }, + user: phoneUser, + }, + { + desc: "Resend phone change for email user", + params: map[string]interface{}{ + "type": "phone_change", + "email": emailUser.GetEmail(), + }, + user: emailUser, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + req := httptest.NewRequest(http.MethodPost, "http://localhost/resend", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + switch c.params["type"] { + case mail.SignupVerification, mail.EmailChangeVerification: + dbUser, err := models.FindUserByID(ts.API.db, c.user.ID) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), dbUser) + + if c.params["type"] == mail.SignupVerification { + require.NotEqual(ts.T(), dbUser.ConfirmationToken, c.user.ConfirmationToken) + require.NotEqual(ts.T(), dbUser.ConfirmationSentAt, c.user.ConfirmationSentAt) + } else if c.params["type"] == mail.EmailChangeVerification { + require.NotEqual(ts.T(), dbUser.EmailChangeTokenNew, c.user.EmailChangeTokenNew) + require.NotEqual(ts.T(), dbUser.EmailChangeSentAt, c.user.EmailChangeSentAt) + } + } + }) + } +} diff --git a/api/router.go b/internal/api/router.go similarity index 95% rename from api/router.go rename to internal/api/router.go index c2f06ae2e..1feb66d3f 100644 --- a/api/router.go +++ b/internal/api/router.go @@ -4,7 +4,7 @@ import ( "context" "net/http" - "github.com/go-chi/chi" + "github.com/go-chi/chi/v5" ) func newRouter() *router { @@ -63,7 +63,7 @@ func handler(fn apiHandler) http.HandlerFunc { func (h apiHandler) serve(w http.ResponseWriter, r *http.Request) { if err := h(w, r); err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) } } @@ -78,7 +78,7 @@ func (m middlewareHandler) handler(next http.Handler) http.Handler { func (m middlewareHandler) serve(next http.Handler, w http.ResponseWriter, r *http.Request) { ctx, err := m(w, r) if err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) return } if ctx != nil { diff --git a/internal/api/saml.go b/internal/api/saml.go new file mode 100644 index 000000000..f32d4436d --- /dev/null +++ b/internal/api/saml.go @@ -0,0 +1,113 @@ +package api + +import ( + "encoding/xml" + "net/http" + "net/url" + "strings" + "time" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" +) + +// getSAMLServiceProvider generates a new service provider object with the +// (optionally) provided descriptor (metadata) for the identity provider. +func (a *API) getSAMLServiceProvider(identityProvider *saml.EntityDescriptor, idpInitiated bool) *saml.ServiceProvider { + var externalURL *url.URL + + if a.config.SAML.ExternalURL != "" { + url, err := url.ParseRequestURI(a.config.SAML.ExternalURL) + if err != nil { + // this should not fail as a.config should have been validated using #Validate() + panic(err) + } + + externalURL = url + } else { + url, err := url.ParseRequestURI(a.config.API.ExternalURL) + if err != nil { + // this should not fail as a.config should have been validated using #Validate() + panic(err) + } + + externalURL = url + } + + if !strings.HasSuffix(externalURL.Path, "/") { + externalURL.Path += "/" + } + + externalURL.Path += "sso/" + + provider := samlsp.DefaultServiceProvider(samlsp.Options{ + URL: *externalURL, + Key: a.config.SAML.RSAPrivateKey, + Certificate: a.config.SAML.Certificate, + SignRequest: true, + AllowIDPInitiated: idpInitiated, + IDPMetadata: identityProvider, + }) + + provider.AuthnNameIDFormat = saml.PersistentNameIDFormat + + return &provider +} + +// SAMLMetadata serves GoTrue's SAML Service Provider metadata file. +func (a *API) SAMLMetadata(w http.ResponseWriter, r *http.Request) error { + serviceProvider := a.getSAMLServiceProvider(nil, true) + + metadata := serviceProvider.Metadata() + + if r.FormValue("download") == "true" { + // 5 year expiration, comparable to what GSuite does + metadata.ValidUntil = time.Now().UTC().AddDate(5, 0, 0) + } + + for i := range metadata.SPSSODescriptors { + // we set this to false since the IdP initiated flow can only + // sign the Assertion, and not the full Request + // unfortunately this is hardcoded in the crewjam library if + // signatures (instead of encryption) are supported + // https://github.com/crewjam/saml/blob/v0.4.8/service_provider.go#L217 + metadata.SPSSODescriptors[i].AuthnRequestsSigned = nil + + // advertize the requested NameID formats (either persistent or email address) + metadata.SPSSODescriptors[i].NameIDFormats = []saml.NameIDFormat{ + saml.EmailAddressNameIDFormat, + saml.PersistentNameIDFormat, + } + } + + for i := range metadata.SPSSODescriptors { + spd := &metadata.SPSSODescriptors[i] + + var keyDescriptors []saml.KeyDescriptor + + for _, kd := range spd.KeyDescriptors { + // only advertize key as usable for encryption if allowed + if kd.Use == "signing" || (a.config.SAML.AllowEncryptedAssertions && kd.Use == "encryption") { + keyDescriptors = append(keyDescriptors, kd) + } + } + + spd.KeyDescriptors = keyDescriptors + } + + metadataXML, err := xml.Marshal(metadata) + if err != nil { + return err + } + + w.Header().Set("Content-Type", "application/xml") + w.Header().Set("Cache-Control", "public, max-age=600") // cache at CDN for 10 minutes + + if r.FormValue("download") == "true" { + w.Header().Set("Content-Disposition", "attachment; filename=\"metadata.xml\"") + } + + _, err = w.Write(metadataXML) + + return err +} diff --git a/internal/api/saml_test.go b/internal/api/saml_test.go new file mode 100644 index 000000000..a290fb2e8 --- /dev/null +++ b/internal/api/saml_test.go @@ -0,0 +1,59 @@ +package api + +import ( + tst "testing" + "time" + + "encoding/xml" + "net/http" + "net/http/httptest" + + "github.com/crewjam/saml" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestSAMLMetadataWithAPI(t *tst.T) { + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + config.API.ExternalURL = "https://projectref.supabase.co/auth/v1/" + config.SAML.Enabled = true + config.SAML.PrivateKey = "MIIEowIBAAKCAQEAszrVveMQcSsa0Y+zN1ZFb19cRS0jn4UgIHTprW2tVBmO2PABzjY3XFCfx6vPirMAPWBYpsKmXrvm1tr0A6DZYmA8YmJd937VUQ67fa6DMyppBYTjNgGEkEhmKuszvF3MARsIKCGtZqUrmS7UG4404wYxVppnr2EYm3RGtHlkYsXu20MBqSDXP47bQP+PkJqC3BuNGk3xt5UHl2FSFpTHelkI6lBynw16B+lUT1F96SERNDaMqi/TRsZdGe5mB/29ngC/QBMpEbRBLNRir5iUevKS7Pn4aph9Qjaxx/97siktK210FJT23KjHpgcUfjoQ6BgPBTLtEeQdRyDuc/CgfwIDAQABAoIBAGYDWOEpupQPSsZ4mjMnAYJwrp4ZISuMpEqVAORbhspVeb70bLKonT4IDcmiexCg7cQBcLQKGpPVM4CbQ0RFazXZPMVq470ZDeWDEyhoCfk3bGtdxc1Zc9CDxNMs6FeQs6r1beEZug6weG5J/yRn/qYxQife3qEuDMl+lzfl2EN3HYVOSnBmdt50dxRuX26iW3nqqbMRqYn9OHuJ1LvRRfYeyVKqgC5vgt/6Tf7DAJwGe0dD7q08byHV8DBZ0pnMVU0bYpf1GTgMibgjnLjK//EVWafFHtN+RXcjzGmyJrk3+7ZyPUpzpDjO21kpzUQLrpEkkBRnmg6bwHnSrBr8avECgYEA3pq1PTCAOuLQoIm1CWR9/dhkbJQiKTJevlWV8slXQLR50P0WvI2RdFuSxlWmA4xZej8s4e7iD3MYye6SBsQHygOVGc4efvvEZV8/XTlDdyj7iLVGhnEmu2r7AFKzy8cOvXx0QcLg+zNd7vxZv/8D3Qj9Jje2LjLHKM5n/dZ3RzUCgYEAzh5Lo2anc4WN8faLGt7rPkGQF+7/18ImQE11joHWa3LzAEy7FbeOGpE/vhOv5umq5M/KlWFIRahMEQv4RusieHWI19ZLIP+JwQFxWxS+cPp3xOiGcquSAZnlyVSxZ//dlVgaZq2o2MfrxECcovRlaknl2csyf+HjFFwKlNxHm2MCgYAr//R3BdEy0oZeVRndo2lr9YvUEmu2LOihQpWDCd0fQw0ZDA2kc28eysL2RROte95r1XTvq6IvX5a0w11FzRWlDpQ4J4/LlcQ6LVt+98SoFwew+/PWuyLmxLycUbyMOOpm9eSc4wJJZNvaUzMCSkvfMtmm5jgyZYMMQ9A2Ul/9SQKBgB9mfh9mhBwVPIqgBJETZMMXOdxrjI5SBYHGSyJqpT+5Q0vIZLfqPrvNZOiQFzwWXPJ+tV4Mc/YorW3rZOdo6tdvEGnRO6DLTTEaByrY/io3/gcBZXoSqSuVRmxleqFdWWRnB56c1hwwWLqNHU+1671FhL6pNghFYVK4suP6qu4BAoGBAMk+VipXcIlD67mfGrET/xDqiWWBZtgTzTMjTpODhDY1GZck1eb4CQMP5j5V3gFJ4cSgWDJvnWg8rcz0unz/q4aeMGl1rah5WNDWj1QKWMS6vJhMHM/rqN1WHWR0ZnV83svYgtg0zDnQKlLujqW4JmGXLMU7ur6a+e6lpa1fvLsP" + config.API.MaxRequestDuration = 5 * time.Second + + require.NoError(t, config.ApplyDefaults()) + require.NoError(t, config.SAML.PopulateFields(config.API.ExternalURL)) + + require.NotNil(t, config.SAML.Certificate) + + api := NewAPI(config, nil) + + // Setup request + req := httptest.NewRequest(http.MethodGet, "http://localhost/sso/saml/metadata", nil) + + w := httptest.NewRecorder() + api.handler.ServeHTTP(w, req) + require.Equal(t, w.Code, http.StatusOK) + + metadata := saml.EntityDescriptor{} + require.NoError(t, xml.Unmarshal(w.Body.Bytes(), &metadata)) + + require.Equal(t, metadata.EntityID, "https://projectref.supabase.co/auth/v1/sso/saml/metadata") + require.Equal(t, len(metadata.SPSSODescriptors), 1) + + require.Nil(t, metadata.SPSSODescriptors[0].AuthnRequestsSigned) + require.True(t, *(metadata.SPSSODescriptors[0].WantAssertionsSigned)) + + require.Equal(t, len(metadata.SPSSODescriptors[0].AssertionConsumerServices), 2) + require.Equal(t, metadata.SPSSODescriptors[0].AssertionConsumerServices[0].Location, "https://projectref.supabase.co/auth/v1/sso/saml/acs") + require.Equal(t, metadata.SPSSODescriptors[0].AssertionConsumerServices[1].Location, "https://projectref.supabase.co/auth/v1/sso/saml/acs") + require.Equal(t, len(metadata.SPSSODescriptors[0].SingleLogoutServices), 1) + require.Equal(t, metadata.SPSSODescriptors[0].SingleLogoutServices[0].Location, "https://projectref.supabase.co/auth/v1/sso/saml/slo") + + require.Equal(t, len(metadata.SPSSODescriptors[0].KeyDescriptors), 1) + require.Equal(t, metadata.SPSSODescriptors[0].KeyDescriptors[0].Use, "signing") + + require.Equal(t, len(metadata.SPSSODescriptors[0].NameIDFormats), 2) + require.Equal(t, metadata.SPSSODescriptors[0].NameIDFormats[0], saml.EmailAddressNameIDFormat) + require.Equal(t, metadata.SPSSODescriptors[0].NameIDFormats[1], saml.PersistentNameIDFormat) +} diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go new file mode 100644 index 000000000..a6dfc42aa --- /dev/null +++ b/internal/api/samlacs.go @@ -0,0 +1,328 @@ +package api + +import ( + "context" + "encoding/base64" + "encoding/json" + "encoding/xml" + "net/http" + "net/url" + "time" + + "github.com/crewjam/saml" + "github.com/fatih/structs" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +func (a *API) samlDestroyRelayState(ctx context.Context, relayState *models.SAMLRelayState) error { + db := a.db.WithContext(ctx) + + // It's OK to destroy the RelayState, as a user will + // likely initiate a completely new login flow, instead + // of reusing the same one. + + return db.Transaction(func(tx *storage.Connection) error { + return tx.Destroy(relayState) + }) +} + +func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models.SAMLProvider) bool { + now := time.Now() + + hasValidityExpired := !idpMetadata.ValidUntil.IsZero() && now.After(idpMetadata.ValidUntil) + hasCacheDurationExceeded := idpMetadata.CacheDuration != 0 && now.After(samlProvider.UpdatedAt.Add(idpMetadata.CacheDuration)) + + // if metadata XML does not publish validity or caching information, update once in 24 hours + needsForceUpdate := idpMetadata.ValidUntil.IsZero() && idpMetadata.CacheDuration == 0 && now.After(samlProvider.UpdatedAt.Add(24*time.Hour)) + + return hasValidityExpired || hasCacheDurationExceeded || needsForceUpdate +} + +func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { + if err := a.handleSamlAcs(w, r); err != nil { + u, uerr := url.Parse(a.config.SiteURL) + if uerr != nil { + return internalServerError("site url is improperly formattted").WithInternalError(err) + } + + q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query()) + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusSeeOther) + } + return nil +} + +// handleSamlAcs implements the main Assertion Consumer Service endpoint behavior. +func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + db := a.db.WithContext(ctx) + config := a.config + log := observability.GetLogEntry(r).Entry + + relayStateValue := r.FormValue("RelayState") + relayStateUUID := uuid.FromStringOrNil(relayStateValue) + relayStateURL, _ := url.ParseRequestURI(relayStateValue) + + entityId := "" + initiatedBy := "" + redirectTo := "" + var requestIds []string + + var flowState *models.FlowState + if relayStateUUID != uuid.Nil { + // relay state is a valid UUID, therefore this is likely a SP initiated flow + + relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) + if models.IsNotFoundError(err) { + return notFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") + } else if err != nil { + return err + } + + if time.Since(relayState.CreatedAt) >= a.config.SAML.RelayStateValidityPeriod { + if err := a.samlDestroyRelayState(ctx, relayState); err != nil { + return internalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) + } + + return unprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") + } + + // TODO: add abuse detection to bind the RelayState UUID with a + // HTTP-Only cookie + + ssoProvider, err := models.FindSSOProviderByID(db, relayState.SSOProviderID) + if err != nil { + return internalServerError("Unable to find SSO Provider from SAML RelayState") + } + + initiatedBy = "sp" + entityId = ssoProvider.SAMLProvider.EntityID + redirectTo = relayState.RedirectTo + requestIds = append(requestIds, relayState.RequestID) + if relayState.FlowState != nil { + flowState = relayState.FlowState + } + + if err := a.samlDestroyRelayState(ctx, relayState); err != nil { + return err + } + } else if relayStateValue == "" || relayStateURL != nil { + // RelayState may be a URL in which case it's the URL where the + // IdP is telling us to redirect the user to + + if r.FormValue("SAMLart") != "" { + // SAML Artifact responses are possible only when + // RelayState can be used to identify the Identity + // Provider. + return badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") + } + + samlResponse := r.FormValue("SAMLResponse") + if samlResponse == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing") + } + + responseXML, err := base64.StdEncoding.DecodeString(samlResponse) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") + } + + var peekResponse saml.Response + err = xml.Unmarshal(responseXML, &peekResponse) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) + } + + initiatedBy = "idp" + entityId = peekResponse.Issuer.Value + redirectTo = relayStateValue + } else { + // RelayState can't be identified, so SAML flow can't continue + return badRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") + } + + ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) + if models.IsNotFoundError(err) { + return notFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") + } else if err != nil { + return err + } + + idpMetadata, err := ssoProvider.SAMLProvider.EntityDescriptor() + if err != nil { + return err + } + + samlMetadataModified := false + + if ssoProvider.SAMLProvider.MetadataURL == nil { + if !idpMetadata.ValidUntil.IsZero() && time.Until(idpMetadata.ValidUntil) <= (30*24*60)*time.Second { + logentry := log.WithField("sso_provider_id", ssoProvider.ID.String()) + logentry = logentry.WithField("expires_in", time.Until(idpMetadata.ValidUntil).String()) + logentry = logentry.WithField("valid_until", idpMetadata.ValidUntil) + logentry = logentry.WithField("saml_entity_id", ssoProvider.SAMLProvider.EntityID) + + logentry.Warn("SAML Metadata for identity provider will expire soon! Update its metadata_xml!") + } + } else if *ssoProvider.SAMLProvider.MetadataURL != "" && IsSAMLMetadataStale(idpMetadata, ssoProvider.SAMLProvider) { + rawMetadata, err := fetchSAMLMetadata(ctx, *ssoProvider.SAMLProvider.MetadataURL) + if err != nil { + // Fail silently but raise warning and continue with existing metadata + logentry := log.WithField("sso_provider_id", ssoProvider.ID.String()) + logentry = logentry.WithField("expires_in", time.Until(idpMetadata.ValidUntil).String()) + logentry = logentry.WithField("valid_until", idpMetadata.ValidUntil) + logentry = logentry.WithError(err) + logentry.Warn("SAML Metadata could not be retrieved, continuing with existing metadata") + } else { + ssoProvider.SAMLProvider.MetadataXML = string(rawMetadata) + samlMetadataModified = true + } + } + + serviceProvider := a.getSAMLServiceProvider(idpMetadata, initiatedBy == "idp") + spAssertion, err := serviceProvider.ParseResponse(r, requestIds) + if err != nil { + if ire, ok := err.(*saml.InvalidResponseError); ok { + return badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr) + } + + return badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) + } + + assertion := SAMLAssertion{ + spAssertion, + } + + userID := assertion.UserID() + if userID == "" { + return badRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + } + + claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) + + email, ok := claims["email"].(string) + if !ok || email == "" { + // mapping does not identify the email attribute, try to figure it out + email = assertion.Email() + } + + if email == "" { + return badRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") + } else { + claims["email"] = email + } + + jsonClaims, err := json.Marshal(claims) + if err != nil { + return internalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err) + } + + providerClaims := &provider.Claims{} + if err := json.Unmarshal(jsonClaims, providerClaims); err != nil { + return internalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err) + } + + providerClaims.Subject = userID + providerClaims.Issuer = ssoProvider.SAMLProvider.EntityID + providerClaims.Email = email + providerClaims.EmailVerified = true + + providerClaimsMap := structs.Map(providerClaims) + + // remove all of the parsed claims, so that the rest can go into CustomClaims + for key := range providerClaimsMap { + delete(claims, key) + } + + providerClaims.CustomClaims = claims + + var userProvidedData provider.UserProvidedData + + userProvidedData.Emails = append(userProvidedData.Emails, provider.Email{ + Email: email, + Verified: true, + Primary: true, + }) + + // userProvidedData.Provider.Type = "saml" + // userProvidedData.Provider.ID = ssoProvider.ID.String() + // userProvidedData.Provider.SAMLEntityID = ssoProvider.SAMLProvider.EntityID + // userProvidedData.Provider.SAMLInitiatedBy = initiatedBy + + userProvidedData.Metadata = providerClaims + + // TODO: below + // refreshTokenParams.SSOProviderID = ssoProvider.ID + // refreshTokenParams.InitiatedByProvider = initiatedBy == "idp" + // refreshTokenParams.NotBefore = assertion.NotBefore() + // refreshTokenParams.NotAfter = assertion.NotAfter() + + notAfter := assertion.NotAfter() + + var grantParams models.GrantParams + + grantParams.FillGrantParams(r) + + if !notAfter.IsZero() { + grantParams.SessionNotAfter = ¬After + } + + var token *AccessTokenResponse + if samlMetadataModified { + if err := db.UpdateColumns(&ssoProvider.SAMLProvider, "metadata_xml", "updated_at"); err != nil { + return err + } + } + + if err := db.Transaction(func(tx *storage.Connection) error { + var terr error + var user *models.User + + // accounts potentially created via SAML can contain non-unique email addresses in the auth.users table + if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, "sso:"+ssoProvider.ID.String()); terr != nil { + return terr + } + if flowState != nil { + // This means that the callback is using PKCE + flowState.UserID = &(user.ID) + if terr := tx.Update(flowState); terr != nil { + return terr + } + } + + token, terr = a.issueRefreshToken(r, tx, user, models.SSOSAML, grantParams) + + if terr != nil { + return internalServerError("Unable to issue refresh token from SAML Assertion").WithInternalError(terr) + } + + return nil + }); err != nil { + return err + } + + if !utilities.IsRedirectURLValid(config, redirectTo) { + redirectTo = config.SiteURL + } + if flowState != nil { + // This means that the callback is using PKCE + // Set the flowState.AuthCode to the query param here + redirectTo, err = a.prepPKCERedirectURL(redirectTo, flowState.AuthCode) + if err != nil { + return err + } + http.Redirect(w, r, redirectTo, http.StatusFound) + return nil + + } + http.Redirect(w, r, token.AsRedirectURL(redirectTo, url.Values{}), http.StatusFound) + + return nil +} diff --git a/internal/api/samlassertion.go b/internal/api/samlassertion.go new file mode 100644 index 000000000..fdf932385 --- /dev/null +++ b/internal/api/samlassertion.go @@ -0,0 +1,188 @@ +package api + +import ( + "strings" + "time" + + "github.com/crewjam/saml" + "github.com/supabase/auth/internal/models" +) + +type SAMLAssertion struct { + *saml.Assertion +} + +const ( + SAMLSubjectIDAttributeName = "urn:oasis:names:tc:SAML:attribute:subject-id" +) + +// Attribute returns the first matching attribute value in the attribute +// statements where name equals the official SAML attribute Name or +// FriendlyName. Returns nil if such an attribute can't be found. +func (a *SAMLAssertion) Attribute(name string) []saml.AttributeValue { + var values []saml.AttributeValue + + for _, stmt := range a.AttributeStatements { + for _, attr := range stmt.Attributes { + if strings.EqualFold(attr.Name, name) || strings.EqualFold(attr.FriendlyName, name) { + values = append(values, attr.Values...) + } + } + } + + return values +} + +// UserID returns the best choice for a persistent user identifier on the +// Identity Provider side. Don't assume the format of the string returned, as +// it's Identity Provider specific. +func (a *SAMLAssertion) UserID() string { + // First we look up the SAMLSubjectIDAttributeName in the attribute + // section of the assertion, as this is the preferred way to + // persistently identify users in SAML 2.0. + // See: https://docs.oasis-open.org/security/saml-subject-id-attr/v1.0/cs01/saml-subject-id-attr-v1.0-cs01.html#_Toc536097226 + values := a.Attribute(SAMLSubjectIDAttributeName) + if len(values) > 0 { + return values[0].Value + } + + // Otherwise, fall back to the SubjectID value. + subjectID, isPersistent := a.SubjectID() + if !isPersistent { + return "" + } + + return subjectID +} + +// SubjectID returns the user identifier in present in the Subject section of +// the SAML assertion. Note that this way of identifying the Subject is +// generally superseded by the SAMLSubjectIDAttributeName assertion attribute; +// tho must be present in all assertions. It can have a few formats, of which +// the most important are: saml.EmailAddressNameIDFormat (meaning the user ID +// is an email address), saml.PersistentNameIDFormat (the user ID is an opaque +// string that does not change with each assertion, e.g. UUID), +// saml.TransientNameIDFormat (the user ID changes with each assertion -- can't +// be used to identify a user). The boolean returned identifies if the user ID +// is persistent. If it's an email address, it's lowercased just in case. +func (a *SAMLAssertion) SubjectID() (string, bool) { + if a.Subject == nil { + return "", false + } + + if a.Subject.NameID == nil { + return "", false + } + + if a.Subject.NameID.Value == "" { + return "", false + } + + if a.Subject.NameID.Format == string(saml.EmailAddressNameIDFormat) { + return strings.ToLower(strings.TrimSpace(a.Subject.NameID.Value)), true + } + + // all other NameID formats are regarded as persistent + isPersistent := a.Subject.NameID.Format != string(saml.TransientNameIDFormat) + + return a.Subject.NameID.Value, isPersistent +} + +// Email returns the best guess for an email address. +func (a *SAMLAssertion) Email() string { + attributeNames := []string{ + "urn:oid:0.9.2342.19200300.100.1.3", + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + "http://schemas.xmlsoap.org/claims/EmailAddress", + "mail", + "Mail", + "email", + } + + for _, name := range attributeNames { + for _, attr := range a.Attribute(name) { + if attr.Value != "" { + return attr.Value + } + } + } + + if a.Subject.NameID.Format == string(saml.EmailAddressNameIDFormat) { + return a.Subject.NameID.Value + } + + return "" +} + +// Process processes this assertion according to the SAMLAttributeMapping. Never returns nil. +func (a *SAMLAssertion) Process(mapping models.SAMLAttributeMapping) map[string]interface{} { + ret := make(map[string]interface{}) + + for key, mapper := range mapping.Keys { + names := []string{} + if mapper.Name != "" { + names = append(names, mapper.Name) + } + names = append(names, mapper.Names...) + + setKey := false + + for _, name := range names { + for _, attr := range a.Attribute(name) { + if attr.Value != "" { + setKey = true + + if mapper.Array { + if ret[key] == nil { + ret[key] = []string{} + } + + ret[key] = append(ret[key].([]string), attr.Value) + } else { + ret[key] = attr.Value + break + } + } + } + + if setKey { + break + } + } + + if !setKey && mapper.Default != nil { + ret[key] = mapper.Default + } + } + + return ret +} + +// NotBefore extracts the time before which this assertion should not be +// considered. +func (a *SAMLAssertion) NotBefore() time.Time { + if a.Conditions != nil && !a.Conditions.NotBefore.IsZero() { + return a.Conditions.NotBefore.UTC() + } + + return time.Time{} +} + +// NotAfter extracts the time at which or after this assertion should not be +// considered. +func (a *SAMLAssertion) NotAfter() time.Time { + var notOnOrAfter time.Time + + for _, statement := range a.AuthnStatements { + if statement.SessionNotOnOrAfter == nil { + continue + } + + notOnOrAfter = *statement.SessionNotOnOrAfter + if !notOnOrAfter.IsZero() { + break + } + } + + return notOnOrAfter +} diff --git a/internal/api/samlassertion_test.go b/internal/api/samlassertion_test.go new file mode 100644 index 000000000..b7461b26d --- /dev/null +++ b/internal/api/samlassertion_test.go @@ -0,0 +1,347 @@ +package api + +import ( + tst "testing" + + "encoding/xml" + + "github.com/crewjam/saml" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" +) + +func TestSAMLAssertionUserID(t *tst.T) { + type spec struct { + xml string + userID string + } + + examples := []spec{ + { + xml: ` + + https://example.com/saml + + + transient-name-id + + + + + + + http://localhost:9999/saml/metadata + + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + +`, + userID: "", + }, + { + xml: ` + + https://example.com/saml + + + persistent-name-id + + + + + + + http://localhost:9999/saml/metadata + + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + +`, + userID: "persistent-name-id", + }, + { + xml: ` + + https://example.com/saml + + + name-id@example.com + + + + + + + http://localhost:9999/saml/metadata + + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + +`, + userID: "name-id@example.com", + }, + { + xml: ` + + https://example.com/saml + + + name-id@example.com + + + + + + + http://localhost:9999/saml/metadata + + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + subject-id + + + +`, + userID: "subject-id", + }, + } + + for i, example := range examples { + rawAssertion := saml.Assertion{} + require.NoError(t, xml.Unmarshal([]byte(example.xml), &rawAssertion)) + + assertion := SAMLAssertion{ + &rawAssertion, + } + + userID := assertion.UserID() + + require.Equal(t, userID, example.userID, "example %d had different user ID", i) + } +} + +func TestSAMLAssertionProcessing(t *tst.T) { + type spec struct { + desc string + xml string + mapping models.SAMLAttributeMapping + expected map[string]interface{} + } + + examples := []spec{ + { + desc: "valid attribute and mapping", + xml: ` + + + + someone@example.com + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Name: "mail", + }, + }, + }, + expected: map[string]interface{}{ + "email": "someone@example.com", + }, + }, + { + desc: "valid attributes, use first attribute found in Names", + xml: ` + + + + old-soap@example.com + + + soap@example.com + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Names: []string{ + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + "http://schemas.xmlsoap.org/claims/EmailAddress", + }, + }, + }, + }, + expected: map[string]interface{}{ + "email": "soap@example.com", + }, + }, + { + desc: "valid groups attribute", + xml: ` + + + + group1 + group2 + + + soap@example.com + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Names: []string{ + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + "http://schemas.xmlsoap.org/claims/EmailAddress", + }, + }, + "groups": { + Name: "groups", + Array: true, + }, + }, + }, + expected: map[string]interface{}{ + "email": "soap@example.com", + "groups": []string{ + "group1", + "group2", + }, + }, + }, + { + desc: "missing attribute use default value", + xml: ` + + + + someone@example.com + + + +`, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Name: "email", + }, + "role": { + Default: "member", + }, + }, + }, + expected: map[string]interface{}{ + "email": "someone@example.com", + "role": "member", + }, + }, + { + desc: "use default value even if attribute exists but is not specified in mapping", + xml: ` + + + + someone@example.com + + + admin + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Name: "mail", + }, + "role": { + Default: "member", + }, + }, + }, + expected: map[string]interface{}{ + "email": "someone@example.com", + "role": "member", + }, + }, + { + desc: "use value in XML when attribute exists and is specified in mapping", + xml: ` + + + + someone@example.com + + + admin + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Name: "mail", + }, + "role": { + Name: "role", + Default: "member", + }, + }, + }, + expected: map[string]interface{}{ + "email": "someone@example.com", + "role": "admin", + }, + }, + } + + for i, example := range examples { + t.Run(example.desc, func(t *tst.T) { + rawAssertion := saml.Assertion{} + require.NoError(t, xml.Unmarshal([]byte(example.xml), &rawAssertion)) + + assertion := SAMLAssertion{ + &rawAssertion, + } + + result := assertion.Process(example.mapping) + require.Equal(t, example.expected, result, "example %d had different processing", i) + }) + } +} diff --git a/internal/api/settings.go b/internal/api/settings.go new file mode 100644 index 000000000..bc2f38692 --- /dev/null +++ b/internal/api/settings.go @@ -0,0 +1,79 @@ +package api + +import "net/http" + +type ProviderSettings struct { + AnonymousUsers bool `json:"anonymous_users"` + Apple bool `json:"apple"` + Azure bool `json:"azure"` + Bitbucket bool `json:"bitbucket"` + Discord bool `json:"discord"` + Facebook bool `json:"facebook"` + Figma bool `json:"figma"` + Fly bool `json:"fly"` + GitHub bool `json:"github"` + GitLab bool `json:"gitlab"` + Google bool `json:"google"` + Keycloak bool `json:"keycloak"` + Kakao bool `json:"kakao"` + Linkedin bool `json:"linkedin"` + LinkedinOIDC bool `json:"linkedin_oidc"` + Notion bool `json:"notion"` + Spotify bool `json:"spotify"` + Slack bool `json:"slack"` + SlackOIDC bool `json:"slack_oidc"` + WorkOS bool `json:"workos"` + Twitch bool `json:"twitch"` + Twitter bool `json:"twitter"` + Email bool `json:"email"` + Phone bool `json:"phone"` + Zoom bool `json:"zoom"` +} + +type Settings struct { + ExternalProviders ProviderSettings `json:"external"` + DisableSignup bool `json:"disable_signup"` + MailerAutoconfirm bool `json:"mailer_autoconfirm"` + PhoneAutoconfirm bool `json:"phone_autoconfirm"` + SmsProvider string `json:"sms_provider"` + SAMLEnabled bool `json:"saml_enabled"` +} + +func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { + config := a.config + + return sendJSON(w, http.StatusOK, &Settings{ + ExternalProviders: ProviderSettings{ + AnonymousUsers: config.External.AnonymousUsers.Enabled, + Apple: config.External.Apple.Enabled, + Azure: config.External.Azure.Enabled, + Bitbucket: config.External.Bitbucket.Enabled, + Discord: config.External.Discord.Enabled, + Facebook: config.External.Facebook.Enabled, + Figma: config.External.Figma.Enabled, + Fly: config.External.Fly.Enabled, + GitHub: config.External.Github.Enabled, + GitLab: config.External.Gitlab.Enabled, + Google: config.External.Google.Enabled, + Kakao: config.External.Kakao.Enabled, + Keycloak: config.External.Keycloak.Enabled, + Linkedin: config.External.Linkedin.Enabled, + LinkedinOIDC: config.External.LinkedinOIDC.Enabled, + Notion: config.External.Notion.Enabled, + Spotify: config.External.Spotify.Enabled, + Slack: config.External.Slack.Enabled, + SlackOIDC: config.External.SlackOIDC.Enabled, + Twitch: config.External.Twitch.Enabled, + Twitter: config.External.Twitter.Enabled, + WorkOS: config.External.WorkOS.Enabled, + Email: config.External.Email.Enabled, + Phone: config.External.Phone.Enabled, + Zoom: config.External.Zoom.Enabled, + }, + DisableSignup: config.DisableSignup, + MailerAutoconfirm: config.Mailer.Autoconfirm, + PhoneAutoconfirm: config.Sms.Autoconfirm, + SmsProvider: config.Sms.Provider, + SAMLEnabled: config.SAML.Enabled, + }) +} diff --git a/api/settings_test.go b/internal/api/settings_test.go similarity index 62% rename from api/settings_test.go rename to internal/api/settings_test.go index c706bd48d..767bcf784 100644 --- a/api/settings_test.go +++ b/internal/api/settings_test.go @@ -11,7 +11,7 @@ import ( ) func TestSettings_DefaultProviders(t *testing.T) { - api, _, _, err := setupAPIForTestForInstance() + api, _, err := setupAPIForTest() require.NoError(t, err) // Setup request @@ -35,16 +35,22 @@ func TestSettings_DefaultProviders(t *testing.T) { require.True(t, p.Notion) require.True(t, p.Spotify) require.True(t, p.Slack) + require.True(t, p.SlackOIDC) require.True(t, p.Google) + require.True(t, p.Kakao) + require.True(t, p.Keycloak) require.True(t, p.Linkedin) + require.True(t, p.LinkedinOIDC) require.True(t, p.GitHub) require.True(t, p.GitLab) - require.True(t, p.SAML) require.True(t, p.Twitch) + require.True(t, p.WorkOS) + require.True(t, p.Zoom) + } func TestSettings_EmailDisabled(t *testing.T) { - api, config, instanceID, err := setupAPIForTestForInstance() + api, config, err := setupAPIForTest() require.NoError(t, err) config.External.Email.Enabled = false @@ -53,8 +59,7 @@ func TestSettings_EmailDisabled(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "http://localhost/settings", nil) req.Header.Set("Content-Type", "application/json") - ctx, err := WithInstanceConfig(context.Background(), config, instanceID) - require.NoError(t, err) + ctx := context.Background() req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -66,27 +71,3 @@ func TestSettings_EmailDisabled(t *testing.T) { p := resp.ExternalProviders require.False(t, p.Email) } - -func TestSettings_ExternalName(t *testing.T) { - api, _, _, err := setupAPIForTestForInstance() - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "http://localhost/settings", nil) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - api.handler.ServeHTTP(w, req) - - require.Equal(t, w.Code, http.StatusOK) - - type SettingsWithExternalName struct { - ExternalLabels struct { - SAML string `json:"saml"` - } `json:"external_labels"` - } - resp := SettingsWithExternalName{} - err = json.NewDecoder(w.Body).Decode(&resp) - require.NoError(t, err) - - n := resp.ExternalLabels - require.Equal(t, n.SAML, "TestSamlName") -} diff --git a/internal/api/signup.go b/internal/api/signup.go new file mode 100644 index 000000000..f3f89a77c --- /dev/null +++ b/internal/api/signup.go @@ -0,0 +1,391 @@ +package api + +import ( + "context" + "net/http" + "time" + + "github.com/fatih/structs" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// SignupParams are the parameters the Signup endpoint accepts +type SignupParams struct { + Email string `json:"email"` + Phone string `json:"phone"` + Password string `json:"password"` + Data map[string]interface{} `json:"data"` + Provider string `json:"-"` + Aud string `json:"-"` + Channel string `json:"channel"` + CodeChallengeMethod string `json:"code_challenge_method"` + CodeChallenge string `json:"code_challenge"` +} + +func (a *API) validateSignupParams(ctx context.Context, p *SignupParams) error { + config := a.config + + if p.Password == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Signup requires a valid password") + } + + if err := a.checkPasswordStrength(ctx, p.Password); err != nil { + return err + } + if p.Email != "" && p.Phone != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on signup.") + } + if p.Provider == "phone" && !sms_provider.IsValidMessageChannel(p.Channel, config) { + return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) + } + // PKCE not needed as phone signups already return access token in body + if p.Phone != "" && p.CodeChallenge != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "PKCE not supported for phone signups") + } + if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { + return err + } + + return nil +} + +func (p *SignupParams) ConfigureDefaults() { + if p.Email != "" { + p.Provider = "email" + } else if p.Phone != "" { + p.Provider = "phone" + } + if p.Data == nil { + p.Data = make(map[string]interface{}) + } + + // For backwards compatibility, we default to SMS if params Channel is not specified + if p.Phone != "" && p.Channel == "" { + p.Channel = sms_provider.SMSProvider + } +} + +func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err error) { + switch params.Provider { + case "email": + user, err = models.NewUser("", params.Email, params.Password, params.Aud, params.Data) + case "phone": + user, err = models.NewUser(params.Phone, "", params.Password, params.Aud, params.Data) + case "anonymous": + user, err = models.NewUser("", "", "", params.Aud, params.Data) + user.IsAnonymous = true + default: + // handles external provider case + user, err = models.NewUser("", params.Email, params.Password, params.Aud, params.Data) + } + if err != nil { + err = internalServerError("Database error creating user").WithInternalError(err) + return + } + user.IsSSOUser = isSSOUser + if user.AppMetaData == nil { + user.AppMetaData = make(map[string]interface{}) + } + + user.Identities = make([]models.Identity, 0) + + if params.Provider != "anonymous" { + // TODO: Deprecate "provider" field + user.AppMetaData["provider"] = params.Provider + + user.AppMetaData["providers"] = []string{params.Provider} + } + + return user, nil +} + +// Signup is the endpoint for registering a new user +func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + db := a.db.WithContext(ctx) + + if config.DisableSignup { + return unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") + } + + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + params.ConfigureDefaults() + + if err := a.validateSignupParams(ctx, params); err != nil { + return err + } + + var err error + flowType := getFlowFromChallenge(params.CodeChallenge) + + var user *models.User + var grantParams models.GrantParams + + grantParams.FillGrantParams(r) + + params.Aud = a.requestAud(ctx, r) + + switch params.Provider { + case "email": + if !config.External.Email.Enabled { + return badRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email signups are disabled") + } + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + user, err = models.IsDuplicatedEmail(db, params.Email, params.Aud, nil) + case "phone": + if !config.External.Phone.Enabled { + return badRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Phone signups are disabled") + } + params.Phone, err = validatePhone(params.Phone) + if err != nil { + return err + } + user, err = models.FindUserByPhoneAndAudience(db, params.Phone, params.Aud) + default: + msg := "" + if config.External.Email.Enabled && config.External.Phone.Enabled { + msg = "Sign up only available with email or phone provider" + } else if config.External.Email.Enabled { + msg = "Sign up only available with email provider" + } else if config.External.Phone.Enabled { + msg = "Sign up only available with phone provider" + } else { + msg = "Sign up with this provider not possible" + } + + return badRequestError(apierrors.ErrorCodeValidationFailed, msg) + } + + if err != nil && !models.IsNotFoundError(err) { + return internalServerError("Database error finding user").WithInternalError(err) + } + + var signupUser *models.User + if user == nil { + // always call this outside of a database transaction as this method + // can be computationally hard and block due to password hashing + signupUser, err = params.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if user != nil { + if (params.Provider == "email" && user.IsConfirmed()) || (params.Provider == "phone" && user.IsPhoneConfirmed()) { + return UserExistsError + } + // do not update the user because we can't be sure of their claimed identity + } else { + user, terr = a.signupNewUser(tx, signupUser) + if terr != nil { + return terr + } + } + identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), params.Provider) + if terr != nil { + if !models.IsNotFoundError(terr) { + return terr + } + identityData := structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + }) + for k, v := range params.Data { + if _, ok := identityData[k]; !ok { + identityData[k] = v + } + } + identity, terr = a.createNewIdentity(tx, user, params.Provider, identityData) + if terr != nil { + return terr + } + if terr := user.RemoveUnconfirmedIdentities(tx, identity); terr != nil { + return terr + } + } + user.Identities = []models.Identity{*identity} + + if params.Provider == "email" && !user.IsConfirmed() { + if config.Mailer.Autoconfirm { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + if terr = user.Confirm(tx); terr != nil { + return internalServerError("Database error updating user").WithInternalError(terr) + } + } else { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + if isPKCEFlow(flowType) { + _, terr := generateFlowState(tx, params.Provider, models.EmailSignup, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + if terr != nil { + return terr + } + } + if terr = a.sendConfirmation(r, tx, user, flowType); terr != nil { + return terr + } + } + } else if params.Provider == "phone" && !user.IsPhoneConfirmed() { + if config.Sms.Autoconfirm { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{ + "provider": params.Provider, + "channel": params.Channel, + }); terr != nil { + return terr + } + if terr = user.ConfirmPhone(tx); terr != nil { + return internalServerError("Database error updating user").WithInternalError(terr) + } + } else { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, params.Channel); terr != nil { + return terr + } + } + } + + return nil + }) + + if err != nil { + if errors.Is(err, UserExistsError) { + err = db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserRepeatedSignUpAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + if config.Mailer.Autoconfirm || config.Sms.Autoconfirm { + return unprocessableEntityError(apierrors.ErrorCodeUserAlreadyExists, "User already registered") + } + sanitizedUser, err := sanitizeUser(user, params) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, sanitizedUser) + } + return err + } + + // handles case where Mailer.Autoconfirm is true or Phone.Autoconfirm is true + if user.IsConfirmed() || user.IsPhoneConfirmed() { + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams) + + if terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + metering.RecordLogin("password", user.ID) + return sendJSON(w, http.StatusOK, token) + } + if user.HasBeenInvited() { + // Remove sensitive fields + user.UserMetaData = map[string]interface{}{} + user.Identities = []models.Identity{} + } + return sendJSON(w, http.StatusOK, user) +} + +// sanitizeUser removes all user sensitive information from the user object +// Should be used whenever we want to prevent information about whether a user is registered or not from leaking +func sanitizeUser(u *models.User, params *SignupParams) (*models.User, error) { + now := time.Now() + + u.ID = uuid.Must(uuid.NewV4()) + + u.Role, u.EmailChange = "", "" + u.CreatedAt, u.UpdatedAt, u.ConfirmationSentAt = now, now, &now + u.LastSignInAt, u.ConfirmedAt, u.EmailChangeSentAt, u.EmailConfirmedAt, u.PhoneConfirmedAt = nil, nil, nil, nil, nil + u.Identities = make([]models.Identity, 0) + u.UserMetaData = params.Data + u.Aud = params.Aud + + // sanitize app_metadata + u.AppMetaData = map[string]interface{}{ + "provider": params.Provider, + "providers": []string{params.Provider}, + } + + // sanitize param fields + switch params.Provider { + case "email": + u.Phone = "" + case "phone": + u.Email = "" + default: + u.Phone, u.Email = "", "" + } + + return u, nil +} + +func (a *API) signupNewUser(conn *storage.Connection, user *models.User) (*models.User, error) { + config := a.config + + err := conn.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = tx.Create(user); terr != nil { + return internalServerError("Database error saving new user").WithInternalError(terr) + } + if terr = user.SetRole(tx, config.JWT.DefaultGroupName); terr != nil { + return internalServerError("Database error updating user").WithInternalError(terr) + } + return nil + }) + if err != nil { + return nil, err + } + + // there may be triggers or generated column values in the database that will modify the + // user data as it is being inserted. thus we load the user object + // again to fetch those changes. + if err := conn.Reload(user); err != nil { + return nil, internalServerError("Database error loading user after sign-up").WithInternalError(err) + } + + return user, nil +} diff --git a/internal/api/signup_test.go b/internal/api/signup_test.go new file mode 100644 index 000000000..3f4783261 --- /dev/null +++ b/internal/api/signup_test.go @@ -0,0 +1,153 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + mail "github.com/supabase/auth/internal/mailer" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type SignupTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestSignup(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &SignupTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *SignupTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) +} + +// TestSignup tests API /signup route +func (ts *SignupTestSuite) TestSignup() { + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "test123", + "data": map[string]interface{}{ + "a": 1, + }, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + assert.Equal(ts.T(), "test@example.com", data.GetEmail()) + assert.Equal(ts.T(), ts.Config.JWT.Aud, data.Aud) + assert.Equal(ts.T(), 1.0, data.UserMetaData["a"]) + assert.Equal(ts.T(), "email", data.AppMetaData["provider"]) + assert.Equal(ts.T(), []interface{}{"email"}, data.AppMetaData["providers"]) +} + +// TestSignupTwice checks to make sure the same email cannot be registered twice +func (ts *SignupTestSuite) TestSignupTwice() { + // Request body + var buffer bytes.Buffer + + encode := func() { + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test1@example.com", + "password": "test123", + "data": map[string]interface{}{ + "a": 1, + }, + })) + } + + encode() + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + y := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(y, req) + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test1@example.com", ts.Config.JWT.Aud) + if err == nil { + require.NoError(ts.T(), u.Confirm(ts.API.db)) + } + + encode() + ts.API.handler.ServeHTTP(w, req) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + assert.NotEqual(ts.T(), u.ID, data.ID) + assert.Equal(ts.T(), "test1@example.com", data.GetEmail()) + assert.Equal(ts.T(), ts.Config.JWT.Aud, data.Aud) + assert.Equal(ts.T(), 1.0, data.UserMetaData["a"]) + assert.Equal(ts.T(), "email", data.AppMetaData["provider"]) + assert.Equal(ts.T(), []interface{}{"email"}, data.AppMetaData["providers"]) +} + +func (ts *SignupTestSuite) TestVerifySignup() { + user, err := models.NewUser("123456789", "test@example.com", "testing", ts.Config.JWT.Aud, nil) + user.ConfirmationToken = "asdf3" + now := time.Now() + user.ConfirmationSentAt = &now + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(user)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken)) + + // Find test user + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Setup request + reqUrl := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.SignupVerification, u.ConfirmationToken) + req := httptest.NewRequest(http.MethodGet, reqUrl, nil) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + + urlVal, err := url.Parse(w.Result().Header.Get("Location")) + require.NoError(ts.T(), err) + v, err := url.ParseQuery(urlVal.Fragment) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), v.Get("access_token")) + require.NotEmpty(ts.T(), v.Get("expires_in")) + require.NotEmpty(ts.T(), v.Get("refresh_token")) +} diff --git a/api/sms_provider/messagebird.go b/internal/api/sms_provider/messagebird.go similarity index 69% rename from api/sms_provider/messagebird.go rename to internal/api/sms_provider/messagebird.go index 0055bad28..05f793903 100644 --- a/api/sms_provider/messagebird.go +++ b/internal/api/sms_provider/messagebird.go @@ -7,7 +7,8 @@ import ( "net/url" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" ) const ( @@ -24,6 +25,7 @@ type MessagebirdResponseRecipients struct { } type MessagebirdResponse struct { + ID string `json:"id"` Recipients MessagebirdResponseRecipients `json:"recipients"` } @@ -54,8 +56,17 @@ func NewMessagebirdProvider(config conf.MessagebirdProviderConfiguration) (SmsPr }, nil } +func (t *MessagebirdProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider: + return t.SendSms(phone, message) + default: + return "", fmt.Errorf("channel type %q is not supported for Messagebird", channel) + } +} + // Send an SMS containing the OTP with Messagebird's API -func (t MessagebirdProvider) SendSms(phone string, message string) error { +func (t *MessagebirdProvider) SendSms(phone string, message string) (string, error) { body := url.Values{ "originator": {t.Config.Originator}, "body": {message}, @@ -64,37 +75,41 @@ func (t MessagebirdProvider) SendSms(phone string, message string) error { "datacoding": {"unicode"}, } - client := &http.Client{} + client := &http.Client{Timeout: defaultTimeout} r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) if err != nil { - return err + return "", err } r.Header.Add("Content-Type", "application/x-www-form-urlencoded") r.Header.Add("Authorization", "AccessKey "+t.Config.AccessKey) res, err := client.Do(r) if err != nil { - return err + return "", err } if res.StatusCode == http.StatusBadRequest || res.StatusCode == http.StatusForbidden || res.StatusCode == http.StatusUnauthorized || res.StatusCode == http.StatusUnprocessableEntity { resp := &MessagebirdErrResponse{} if err := json.NewDecoder(res.Body).Decode(resp); err != nil { - return err + return "", err } - return resp + return "", resp } - defer res.Body.Close() + defer utilities.SafeClose(res.Body) // validate sms status resp := &MessagebirdResponse{} derr := json.NewDecoder(res.Body).Decode(resp) if derr != nil { - return derr + return "", derr } if resp.Recipients.TotalSentCount == 0 { - return fmt.Errorf("Messagebird error: total sent count is 0") + return "", fmt.Errorf("messagebird error: total sent count is 0") } - return nil + return resp.ID, nil +} + +func (t *MessagebirdProvider) VerifyOTP(phone, code string) error { + return fmt.Errorf("VerifyOTP is not supported for Messagebird") } diff --git a/internal/api/sms_provider/sms_provider.go b/internal/api/sms_provider/sms_provider.go new file mode 100644 index 000000000..103db4f95 --- /dev/null +++ b/internal/api/sms_provider/sms_provider.go @@ -0,0 +1,70 @@ +package sms_provider + +import ( + "fmt" + "log" + "os" + "time" + + "github.com/supabase/auth/internal/conf" +) + +// overrides the SmsProvider set to always return the mock provider +var MockProvider SmsProvider = nil + +var defaultTimeout time.Duration = time.Second * 10 + +const SMSProvider = "sms" +const WhatsappProvider = "whatsapp" + +func init() { + timeoutStr := os.Getenv("GOTRUE_INTERNAL_HTTP_TIMEOUT") + if timeoutStr != "" { + if timeout, err := time.ParseDuration(timeoutStr); err != nil { + log.Fatalf("error loading GOTRUE_INTERNAL_HTTP_TIMEOUT: %v", err.Error()) + } else if timeout != 0 { + defaultTimeout = timeout + } + } +} + +type SmsProvider interface { + SendMessage(phone, message, channel, otp string) (string, error) + VerifyOTP(phone, token string) error +} + +func GetSmsProvider(config conf.GlobalConfiguration) (SmsProvider, error) { + if MockProvider != nil { + return MockProvider, nil + } + + switch name := config.Sms.Provider; name { + case "twilio": + return NewTwilioProvider(config.Sms.Twilio) + case "messagebird": + return NewMessagebirdProvider(config.Sms.Messagebird) + case "textlocal": + return NewTextlocalProvider(config.Sms.Textlocal) + case "vonage": + return NewVonageProvider(config.Sms.Vonage) + case "twilio_verify": + return NewTwilioVerifyProvider(config.Sms.TwilioVerify) + default: + return nil, fmt.Errorf("sms Provider %s could not be found", name) + } +} + +func IsValidMessageChannel(channel string, config *conf.GlobalConfiguration) bool { + if config.Hook.SendSMS.Enabled { + // channel doesn't matter if SMS hook is enabled + return true + } + switch channel { + case SMSProvider: + return true + case WhatsappProvider: + return config.Sms.Provider == "twilio" || config.Sms.Provider == "twilio_verify" + default: + return false + } +} diff --git a/internal/api/sms_provider/sms_provider_test.go b/internal/api/sms_provider/sms_provider_test.go new file mode 100644 index 000000000..e5b5216a3 --- /dev/null +++ b/internal/api/sms_provider/sms_provider_test.go @@ -0,0 +1,287 @@ +package sms_provider + +import ( + "encoding/base64" + "fmt" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "gopkg.in/h2non/gock.v1" +) + +var handleApiRequest func(*http.Request) (*http.Response, error) + +type SmsProviderTestSuite struct { + suite.Suite + Config *conf.GlobalConfiguration +} + +type MockHttpClient struct { + mock.Mock +} + +func (m *MockHttpClient) Do(req *http.Request) (*http.Response, error) { + return handleApiRequest(req) +} + +func TestSmsProvider(t *testing.T) { + ts := &SmsProviderTestSuite{ + Config: &conf.GlobalConfiguration{ + Sms: conf.SmsProviderConfiguration{ + Twilio: conf.TwilioProviderConfiguration{ + AccountSid: "test_account_sid", + AuthToken: "test_auth_token", + MessageServiceSid: "test_message_service_id", + }, + TwilioVerify: conf.TwilioVerifyProviderConfiguration{ + AccountSid: "test_account_sid", + AuthToken: "test_auth_token", + MessageServiceSid: "test_message_service_id", + }, + Messagebird: conf.MessagebirdProviderConfiguration{ + AccessKey: "test_access_key", + Originator: "test_originator", + }, + Vonage: conf.VonageProviderConfiguration{ + ApiKey: "test_api_key", + ApiSecret: "test_api_secret", + From: "test_from", + }, + Textlocal: conf.TextlocalProviderConfiguration{ + ApiKey: "test_api_key", + Sender: "test_sender", + }, + }, + }, + } + suite.Run(t, ts) +} + +func (ts *SmsProviderTestSuite) TestTwilioSendSms() { + defer gock.Off() + provider, err := NewTwilioProvider(ts.Config.Sms.Twilio) + require.NoError(ts.T(), err) + + twilioProvider, ok := provider.(*TwilioProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + + body := url.Values{ + "To": {"+" + phone}, + "Channel": {"sms"}, + "From": {twilioProvider.Config.MessageServiceSid}, + "Body": {message}, + } + + cases := []struct { + Desc string + TwilioResponse *gock.Response + ExpectedError error + OTP string + }{ + { + Desc: "Successfully sent sms", + TwilioResponse: gock.New(twilioProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioProvider.Config.AccountSid+":"+twilioProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(200).JSON(SmsStatus{ + To: "+" + phone, + From: twilioProvider.Config.MessageServiceSid, + Status: "sent", + Body: message, + MessageSID: "abcdef", + }), + OTP: "123456", + ExpectedError: nil, + }, + { + Desc: "Sms status is failed / undelivered", + TwilioResponse: gock.New(twilioProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioProvider.Config.AccountSid+":"+twilioProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(200).JSON(SmsStatus{ + ErrorMessage: "failed to send sms", + ErrorCode: "401", + Status: "failed", + MessageSID: "abcdef", + }), + ExpectedError: fmt.Errorf("twilio error: %v %v for message %v", "failed to send sms", "401", "abcdef"), + }, + { + Desc: "Non-2xx status code returned", + TwilioResponse: gock.New(twilioProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioProvider.Config.AccountSid+":"+twilioProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(500).JSON(twilioErrResponse{ + Code: 500, + Message: "Internal server error", + MoreInfo: "error", + Status: 500, + }), + OTP: "123456", + ExpectedError: &twilioErrResponse{ + Code: 500, + Message: "Internal server error", + MoreInfo: "error", + Status: 500, + }, + }, + } + + for _, c := range cases { + ts.Run(c.Desc, func() { + _, err = twilioProvider.SendSms(phone, message, SMSProvider, c.OTP) + require.Equal(ts.T(), c.ExpectedError, err) + }) + } +} + +func (ts *SmsProviderTestSuite) TestMessagebirdSendSms() { + defer gock.Off() + provider, err := NewMessagebirdProvider(ts.Config.Sms.Messagebird) + require.NoError(ts.T(), err) + + messagebirdProvider, ok := provider.(*MessagebirdProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + body := url.Values{ + "originator": {messagebirdProvider.Config.Originator}, + "body": {message}, + "recipients": {phone}, + "type": {"sms"}, + "datacoding": {"unicode"}, + } + gock.New(messagebirdProvider.APIPath).Post("").MatchHeader("Authorization", "AccessKey "+messagebirdProvider.Config.AccessKey).MatchType("url").BodyString(body.Encode()).Reply(200).JSON(MessagebirdResponse{ + Recipients: MessagebirdResponseRecipients{ + TotalSentCount: 1, + }, + }) + + _, err = messagebirdProvider.SendSms(phone, message) + require.NoError(ts.T(), err) +} + +func (ts *SmsProviderTestSuite) TestVonageSendSms() { + defer gock.Off() + provider, err := NewVonageProvider(ts.Config.Sms.Vonage) + require.NoError(ts.T(), err) + + vonageProvider, ok := provider.(*VonageProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + + body := url.Values{ + "from": {vonageProvider.Config.From}, + "to": {phone}, + "text": {message}, + "api_key": {vonageProvider.Config.ApiKey}, + "api_secret": {vonageProvider.Config.ApiSecret}, + } + + gock.New(vonageProvider.APIPath).Post("").MatchType("url").BodyString(body.Encode()).Reply(200).JSON(VonageResponse{ + Messages: []VonageResponseMessage{ + {Status: "0"}, + }, + }) + + _, err = vonageProvider.SendSms(phone, message) + require.NoError(ts.T(), err) +} + +func (ts *SmsProviderTestSuite) TestTextLocalSendSms() { + defer gock.Off() + provider, err := NewTextlocalProvider(ts.Config.Sms.Textlocal) + require.NoError(ts.T(), err) + + textlocalProvider, ok := provider.(*TextlocalProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + body := url.Values{ + "sender": {textlocalProvider.Config.Sender}, + "apikey": {textlocalProvider.Config.ApiKey}, + "message": {message}, + "numbers": {phone}, + } + + gock.New(textlocalProvider.APIPath).Post("").MatchType("url").BodyString(body.Encode()).Reply(200).JSON(TextlocalResponse{ + Status: "success", + Errors: []TextlocalError{}, + }) + + _, err = textlocalProvider.SendSms(phone, message) + require.NoError(ts.T(), err) +} +func (ts *SmsProviderTestSuite) TestTwilioVerifySendSms() { + defer gock.Off() + provider, err := NewTwilioVerifyProvider(ts.Config.Sms.TwilioVerify) + require.NoError(ts.T(), err) + + twilioVerifyProvider, ok := provider.(*TwilioVerifyProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + + body := url.Values{ + "To": {"+" + phone}, + "Channel": {"sms"}, + } + + cases := []struct { + Desc string + TwilioResponse *gock.Response + ExpectedError error + }{ + { + Desc: "Successfully sent sms", + TwilioResponse: gock.New(twilioVerifyProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioVerifyProvider.Config.AccountSid+":"+twilioVerifyProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(200).JSON(SmsStatus{ + To: "+" + phone, + From: twilioVerifyProvider.Config.MessageServiceSid, + Status: "sent", + Body: message, + }), + ExpectedError: nil, + }, + { + Desc: "Non-2xx status code returned", + TwilioResponse: gock.New(twilioVerifyProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioVerifyProvider.Config.AccountSid+":"+twilioVerifyProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(500).JSON(twilioErrResponse{ + Code: 500, + Message: "Internal server error", + MoreInfo: "error", + Status: 500, + }), + ExpectedError: &twilioErrResponse{ + Code: 500, + Message: "Internal server error", + MoreInfo: "error", + Status: 500, + }, + }, + } + + for _, c := range cases { + ts.Run(c.Desc, func() { + _, err = twilioVerifyProvider.SendSms(phone, message, SMSProvider) + require.Equal(ts.T(), c.ExpectedError, err) + }) + } +} diff --git a/internal/api/sms_provider/textlocal.go b/internal/api/sms_provider/textlocal.go new file mode 100644 index 000000000..ef07a6f39 --- /dev/null +++ b/internal/api/sms_provider/textlocal.go @@ -0,0 +1,107 @@ +package sms_provider + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +const ( + defaultTextLocalApiBase = "https://api.textlocal.in" + textLocalTemplateErrorCode = 80 +) + +type TextlocalProvider struct { + Config *conf.TextlocalProviderConfiguration + APIPath string +} + +type TextlocalError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type TextlocalResponse struct { + Status string `json:"status"` + Errors []TextlocalError `json:"errors"` + Messages []TextlocalMessage `json:"messages"` +} + +type TextlocalMessage struct { + MessageID string `json:"id"` +} + +// Creates a SmsProvider with the Textlocal Config +func NewTextlocalProvider(config conf.TextlocalProviderConfiguration) (SmsProvider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + apiPath := defaultTextLocalApiBase + "/send" + return &TextlocalProvider{ + Config: &config, + APIPath: apiPath, + }, nil +} + +func (t *TextlocalProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider: + return t.SendSms(phone, message) + default: + return "", fmt.Errorf("channel type %q is not supported for TextLocal", channel) + } +} + +// Send an SMS containing the OTP with Textlocal's API +func (t *TextlocalProvider) SendSms(phone string, message string) (string, error) { + body := url.Values{ + "sender": {t.Config.Sender}, + "apikey": {t.Config.ApiKey}, + "message": {message}, + "numbers": {phone}, + } + + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + res, err := client.Do(r) + if err != nil { + return "", err + } + defer utilities.SafeClose(res.Body) + + resp := &TextlocalResponse{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return "", derr + } + + messageID := "" + + if resp.Status != "success" { + if len(resp.Messages) > 0 { + messageID = resp.Messages[0].MessageID + } + + if len(resp.Errors) > 0 && resp.Errors[0].Code == textLocalTemplateErrorCode { + return messageID, fmt.Errorf("textlocal error: %v (code: %v) template message: %s", resp.Errors[0].Message, resp.Errors[0].Code, message) + } + + return messageID, fmt.Errorf("textlocal error: %v (code: %v) message %s", resp.Errors[0].Message, resp.Errors[0].Code, messageID) + } + + return messageID, nil +} +func (t *TextlocalProvider) VerifyOTP(phone, code string) error { + return fmt.Errorf("VerifyOTP is not supported for Textlocal") +} diff --git a/internal/api/sms_provider/twilio.go b/internal/api/sms_provider/twilio.go new file mode 100644 index 000000000..3536c2f17 --- /dev/null +++ b/internal/api/sms_provider/twilio.go @@ -0,0 +1,141 @@ +package sms_provider + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "regexp" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +const ( + defaultTwilioApiBase = "https://api.twilio.com" + apiVersion = "2010-04-01" +) + +type TwilioProvider struct { + Config *conf.TwilioProviderConfiguration + APIPath string +} + +var isPhoneNumber = regexp.MustCompile("^[1-9][0-9]{1,14}$") + +// formatPhoneNumber removes "+" and whitespaces in a phone number +func formatPhoneNumber(phone string) string { + return strings.ReplaceAll(strings.TrimPrefix(phone, "+"), " ", "") +} + +type SmsStatus struct { + To string `json:"to"` + From string `json:"from"` + MessageSID string `json:"sid"` + Status string `json:"status"` + ErrorCode string `json:"error_code"` + ErrorMessage string `json:"error_message"` + Body string `json:"body"` +} + +type twilioErrResponse struct { + Code int `json:"code"` + Message string `json:"message"` + MoreInfo string `json:"more_info"` + Status int `json:"status"` +} + +func (t twilioErrResponse) Error() string { + return fmt.Sprintf("%s More information: %s", t.Message, t.MoreInfo) +} + +// Creates a SmsProvider with the Twilio Config +func NewTwilioProvider(config conf.TwilioProviderConfiguration) (SmsProvider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + apiPath := defaultTwilioApiBase + "/" + apiVersion + "/" + "Accounts" + "/" + config.AccountSid + "/Messages.json" + return &TwilioProvider{ + Config: &config, + APIPath: apiPath, + }, nil +} + +func (t *TwilioProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider, WhatsappProvider: + return t.SendSms(phone, message, channel, otp) + default: + return "", fmt.Errorf("channel type %q is not supported for Twilio", channel) + } +} + +// Send an SMS containing the OTP with Twilio's API +func (t *TwilioProvider) SendSms(phone, message, channel, otp string) (string, error) { + sender := t.Config.MessageServiceSid + receiver := "+" + phone + body := url.Values{ + "To": {receiver}, // twilio api requires "+" extension to be included + "Channel": {channel}, + "From": {sender}, + "Body": {message}, + } + if channel == WhatsappProvider { + receiver = channel + ":" + receiver + if isPhoneNumber.MatchString(formatPhoneNumber(sender)) { + sender = channel + ":" + sender + } + + // Programmable Messaging (WhatsApp) takes in different set of inputs + body = url.Values{ + "To": {receiver}, // twilio api requires "+" extension to be included + "Channel": {channel}, + "From": {sender}, + } + // For backward compatibility with old API. + if t.Config.ContentSid != "" { + // Used to substitute OTP. See https://www.twilio.com/docs/content/whatsappauthentication for more details + contentVariables := fmt.Sprintf(`{"1": "%s"}`, otp) + body.Set("ContentSid", t.Config.ContentSid) + body.Set("ContentVariables", contentVariables) + } else { + body.Set("Body", message) + } + } + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.SetBasicAuth(t.Config.AccountSid, t.Config.AuthToken) + res, err := client.Do(r) + if err != nil { + return "", err + } + defer utilities.SafeClose(res.Body) + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusCreated { + resp := &twilioErrResponse{} + if err := json.NewDecoder(res.Body).Decode(resp); err != nil { + return "", err + } + return "", resp + } + // validate sms status + resp := &SmsStatus{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return "", derr + } + + if resp.Status == "failed" || resp.Status == "undelivered" { + return resp.MessageSID, fmt.Errorf("twilio error: %v %v for message %s", resp.ErrorMessage, resp.ErrorCode, resp.MessageSID) + } + + return resp.MessageSID, nil +} +func (t *TwilioProvider) VerifyOTP(phone, code string) error { + return fmt.Errorf("VerifyOTP is not supported for Twilio") +} diff --git a/internal/api/sms_provider/twilio_verify.go b/internal/api/sms_provider/twilio_verify.go new file mode 100644 index 000000000..8ec546396 --- /dev/null +++ b/internal/api/sms_provider/twilio_verify.go @@ -0,0 +1,139 @@ +package sms_provider + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +const ( + verifyServiceApiBase = "https://verify.twilio.com/v2/Services/" +) + +type TwilioVerifyProvider struct { + Config *conf.TwilioVerifyProviderConfiguration + APIPath string +} + +type VerificationResponse struct { + To string `json:"to"` + Status string `json:"status"` + Channel string `json:"channel"` + Valid bool `json:"valid"` + VerificationSID string `json:"sid"` + ErrorCode string `json:"error_code"` + ErrorMessage string `json:"error_message"` +} + +// See: https://www.twilio.com/docs/verify/api/verification-check +type VerificationCheckResponse struct { + To string `json:"to"` + Status string `json:"status"` + Channel string `json:"channel"` + Valid bool `json:"valid"` + ErrorCode string `json:"error_code"` + ErrorMessage string `json:"error_message"` +} + +// Creates a SmsProvider with the Twilio Config +func NewTwilioVerifyProvider(config conf.TwilioVerifyProviderConfiguration) (SmsProvider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + apiPath := verifyServiceApiBase + config.MessageServiceSid + "/Verifications" + + return &TwilioVerifyProvider{ + Config: &config, + APIPath: apiPath, + }, nil +} + +func (t *TwilioVerifyProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider, WhatsappProvider: + return t.SendSms(phone, message, channel) + default: + return "", fmt.Errorf("channel type %q is not supported for Twilio", channel) + } +} + +// Send an SMS containing the OTP with Twilio's API +func (t *TwilioVerifyProvider) SendSms(phone, message, channel string) (string, error) { + // Unlike Programmable Messaging, Verify does not require a prefix for channel + receiver := "+" + phone + body := url.Values{ + "To": {receiver}, + "Channel": {channel}, + } + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.SetBasicAuth(t.Config.AccountSid, t.Config.AuthToken) + res, err := client.Do(r) + if err != nil { + return "", err + } + defer utilities.SafeClose(res.Body) + if !(res.StatusCode == http.StatusOK || res.StatusCode == http.StatusCreated) { + resp := &twilioErrResponse{} + if err := json.NewDecoder(res.Body).Decode(resp); err != nil { + return "", err + } + return "", resp + } + + resp := &VerificationResponse{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return "", derr + } + return resp.VerificationSID, nil +} + +func (t *TwilioVerifyProvider) VerifyOTP(phone, code string) error { + verifyPath := verifyServiceApiBase + t.Config.MessageServiceSid + "/VerificationCheck" + receiver := "+" + phone + + body := url.Values{ + "To": {receiver}, // twilio api requires "+" extension to be included + "Code": {code}, + } + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", verifyPath, strings.NewReader(body.Encode())) + if err != nil { + return err + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.SetBasicAuth(t.Config.AccountSid, t.Config.AuthToken) + res, err := client.Do(r) + if err != nil { + return err + } + defer utilities.SafeClose(res.Body) + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusCreated { + resp := &twilioErrResponse{} + if err := json.NewDecoder(res.Body).Decode(resp); err != nil { + return err + } + return resp + } + resp := &VerificationCheckResponse{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return derr + } + + if resp.Status != "approved" || !resp.Valid { + return fmt.Errorf("twilio verification error: %v %v", resp.ErrorMessage, resp.Status) + } + + return nil +} diff --git a/api/sms_provider/vonage.go b/internal/api/sms_provider/vonage.go similarity index 55% rename from api/sms_provider/vonage.go rename to internal/api/sms_provider/vonage.go index e266dbcc5..4b9fd5b74 100644 --- a/api/sms_provider/vonage.go +++ b/internal/api/sms_provider/vonage.go @@ -8,7 +8,9 @@ import ( "net/url" "strings" - "github.com/netlify/gotrue/conf" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/exp/utf8string" ) const ( @@ -21,6 +23,7 @@ type VonageProvider struct { } type VonageResponseMessage struct { + MessageID string `json:"message-id"` Status string `json:"status"` ErrorText string `json:"error-text"` } @@ -42,8 +45,17 @@ func NewVonageProvider(config conf.VonageProviderConfiguration) (SmsProvider, er }, nil } +func (t *VonageProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider: + return t.SendSms(phone, message) + default: + return "", fmt.Errorf("channel type %q is not supported for Vonage", channel) + } +} + // Send an SMS containing the OTP with Vonage's API -func (t VonageProvider) SendSms(phone string, message string) error { +func (t *VonageProvider) SendSms(phone string, message string) (string, error) { body := url.Values{ "from": {t.Config.From}, "to": {phone}, @@ -52,33 +64,42 @@ func (t VonageProvider) SendSms(phone string, message string) error { "api_secret": {t.Config.ApiSecret}, } - client := &http.Client{} + isMessageContainUnicode := !utf8string.NewString(message).IsASCII() + if isMessageContainUnicode { + body.Set("type", "unicode") + } + + client := &http.Client{Timeout: defaultTimeout} r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) if err != nil { - return err + return "", err } r.Header.Add("Content-Type", "application/x-www-form-urlencoded") res, err := client.Do(r) if err != nil { - return err + return "", err } - defer res.Body.Close() + defer utilities.SafeClose(res.Body) resp := &VonageResponse{} derr := json.NewDecoder(res.Body).Decode(resp) if derr != nil { - return derr + return "", derr } if len(resp.Messages) <= 0 { - return errors.New("Vonage error: Internal Error") + return "", errors.New("vonage error: Internal Error") } // A status of zero indicates success; a non-zero value means something went wrong. if resp.Messages[0].Status != "0" { - return fmt.Errorf("Vonage error: %v (status: %v)", resp.Messages[0].ErrorText, resp.Messages[0].Status) + return resp.Messages[0].MessageID, fmt.Errorf("vonage error: %v (status: %v) for message %s", resp.Messages[0].ErrorText, resp.Messages[0].Status, resp.Messages[0].MessageID) } - return nil + return resp.Messages[0].MessageID, nil +} + +func (t *VonageProvider) VerifyOTP(phone, code string) error { + return fmt.Errorf("VerifyOTP is not supported for Vonage") } diff --git a/api/sorting.go b/internal/api/sorting.go similarity index 95% rename from api/sorting.go rename to internal/api/sorting.go index d3999bd9f..f951d95f9 100644 --- a/api/sorting.go +++ b/internal/api/sorting.go @@ -5,7 +5,7 @@ import ( "net/http" "strings" - "github.com/netlify/gotrue/models" + "github.com/supabase/auth/internal/models" ) func sort(r *http.Request, allowedFields map[string]bool, defaultSort []models.SortField) (*models.SortParams, error) { diff --git a/internal/api/sso.go b/internal/api/sso.go new file mode 100644 index 000000000..39304fe78 --- /dev/null +++ b/internal/api/sso.go @@ -0,0 +1,148 @@ +package api + +import ( + "net/http" + + "github.com/crewjam/saml" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +type SingleSignOnParams struct { + ProviderID uuid.UUID `json:"provider_id"` + Domain string `json:"domain"` + RedirectTo string `json:"redirect_to"` + SkipHTTPRedirect *bool `json:"skip_http_redirect"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` +} + +type SingleSignOnResponse struct { + URL string `json:"url"` +} + +func (p *SingleSignOnParams) validate() (bool, error) { + hasProviderID := p.ProviderID != uuid.Nil + hasDomain := p.Domain != "" + + if hasProviderID && hasDomain { + return hasProviderID, badRequestError(apierrors.ErrorCodeValidationFailed, "Only one of provider_id or domain supported") + } else if !hasProviderID && !hasDomain { + return hasProviderID, badRequestError(apierrors.ErrorCodeValidationFailed, "A provider_id or domain needs to be provided") + } + + return hasProviderID, nil +} + +// SingleSignOn handles the single-sign-on flow for a provided SSO domain or provider. +func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + params := &SingleSignOnParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + var err error + hasProviderID := false + + if hasProviderID, err = params.validate(); err != nil { + return err + } + codeChallengeMethod := params.CodeChallengeMethod + codeChallenge := params.CodeChallenge + + if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil { + return err + } + flowType := getFlowFromChallenge(params.CodeChallenge) + var flowStateID *uuid.UUID + flowStateID = nil + if isPKCEFlow(flowType) { + flowState, err := generateFlowState(db, models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil) + if err != nil { + return err + } + flowStateID = &flowState.ID + } + + var ssoProvider *models.SSOProvider + + if hasProviderID { + ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID) + if models.IsNotFoundError(err) { + return notFoundError(apierrors.ErrorCodeSSOProviderNotFound, "No such SSO provider") + } else if err != nil { + return internalServerError("Unable to find SSO provider by ID").WithInternalError(err) + } + } else { + ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain) + if models.IsNotFoundError(err) { + return notFoundError(apierrors.ErrorCodeSSOProviderNotFound, "No SSO provider assigned for this domain") + } else if err != nil { + return internalServerError("Unable to find SSO provider by domain").WithInternalError(err) + } + } + + entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor() + if err != nil { + return internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err) + } + + serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */) + + authnRequest, err := serviceProvider.MakeAuthenticationRequest( + serviceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding), + saml.HTTPRedirectBinding, + saml.HTTPPostBinding, + ) + if err != nil { + return internalServerError("Error creating SAML Authentication Request").WithInternalError(err) + } + + // Some IdPs do not support the use of the `persistent` NameID format, + // and require a different format to be sent to work. + if ssoProvider.SAMLProvider.NameIDFormat != nil { + authnRequest.NameIDPolicy.Format = ssoProvider.SAMLProvider.NameIDFormat + } + + relayState := models.SAMLRelayState{ + SSOProviderID: ssoProvider.ID, + RequestID: authnRequest.ID, + RedirectTo: params.RedirectTo, + FlowStateID: flowStateID, + } + + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(&relayState); terr != nil { + return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err) + } + + return nil + }); err != nil { + return err + } + + ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider) + if err != nil { + return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err) + } + + skipHTTPRedirect := false + + if params.SkipHTTPRedirect != nil { + skipHTTPRedirect = *params.SkipHTTPRedirect + } + + if skipHTTPRedirect { + return sendJSON(w, http.StatusOK, SingleSignOnResponse{ + URL: ssoRedirectURL.String(), + }) + } + + http.Redirect(w, r, ssoRedirectURL.String(), http.StatusSeeOther) + return nil +} diff --git a/internal/api/sso_test.go b/internal/api/sso_test.go new file mode 100644 index 000000000..bae1bebf3 --- /dev/null +++ b/internal/api/sso_test.go @@ -0,0 +1,752 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +const dateInPast = "2001-02-03T04:05:06.789" +const dateInFarFuture = "2999-02-03T04:05:06.789" +const oneHour = "PT1H" + +type SSOTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + AdminJWT string +} + +func TestSSO(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &SSOTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + if config.SAML.Enabled { + suite.Run(t, ts) + } +} + +func (ts *SSOTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err, "Error generating admin jwt") + + ts.AdminJWT = token +} + +func (ts *SSOTestSuite) TestNonAdminJWT() { + // TODO +} + +func (ts *SSOTestSuite) TestAdminListEmptySSOProviders() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/admin/sso/providers", nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + body, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + var result struct { + Items []interface{} `json:"items"` + NextToken string `json:"next_token"` + } + + require.NoError(ts.T(), json.Unmarshal(body, &result)) + require.Equal(ts.T(), len(result.Items), 0) + require.Equal(ts.T(), result.NextToken, "") +} + +func (ts *SSOTestSuite) TestAdminGetSSOProviderNotExist() { + examples := []struct { + URL string + }{ + { + URL: "http://localhost/admin/sso/providers/not-a-uuid", + }, + { + URL: "http://localhost/admin/sso/providers/677477db-3f51-4038-bc05-c6bb9bdc3c32", + }, + } + + for _, example := range examples { + req := httptest.NewRequest(http.MethodGet, example.URL, nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) + } +} + +func configurableSAMLIDPMetadata(entityID, validUntil, cacheDuration string) string { + return fmt.Sprintf(` + + + + + MIIDdDCCAlygAwIBAgIGAYKSjRZiMA0GCSqGSIb3DQEBCwUAMHsxFDASBgNVBAoTC0dvb2dsZSBJ +bmMuMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MQ8wDQYDVQQDEwZHb29nbGUxGDAWBgNVBAsTD0dv +b2dsZSBGb3IgV29yazELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWEwHhcNMjIwODEy +MTQ1NDU1WhcNMjcwODExMTQ1NDU1WjB7MRQwEgYDVQQKEwtHb29nbGUgSW5jLjEWMBQGA1UEBxMN +TW91bnRhaW4gVmlldzEPMA0GA1UEAxMGR29vZ2xlMRgwFgYDVQQLEw9Hb29nbGUgRm9yIFdvcmsx +CzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAlncFzErcnZm7ZWO71NZStnCIAoYNKf6Uw3LPLzcvk0YrA/eBC3PVDHSfahi+apGO +Ytdq7IQUvBdto3rJTvP49fjyO0WLbAbiPC+dILt2Gx9kttxpSp99Bf+8ObL/fTy5Y2oHbJBfBX1V +qfDQIY0fcej3AndFYUOE0gZXyeSbnROB8W1PzHxOc7rq1mlas0rvyja7AK4gwXjIwyIGsFDmHnve +buqWOYMzOT9oD+iQq9BWYVHkXGZn0BXzKtnw9w8I3IxQdndUoCl95pYRIvdl1b0dWdO9cXtSsTkL +kAa8B/mCQcF4W2M3t/yKtrcLcRTALg3/Hc+Xz+3BpY/fSDk1SwIDAQABMA0GCSqGSIb3DQEBCwUA +A4IBAQCER02WLf6bKwTGVD/3VTntetIiETuPs46Dum8blbsg+2BYdAHIQcB9cLuMRosIw0nYj54m +SfiyfoWGcx3CkMup1MtKyWu+SqDHl9Bpf+GFLG0ngKD/zB6xwpv/TCi+g/FBYe2TvzD6B1V0z7Vs +Xf+Gc2TWBKmCuKf/g2AUt7IQLpOaqxuJVoZjp4sEMov6d3FnaoHQEd0lg+XmnYfLNtwe3QRSU0BD +x6lVV4kXi0x0n198/gkjnA85rPZoZ6dmqHtkcM0Gabgg6KEE5ubSDlWDsdv27uANceCZAoxd1+in +4/KqqkhynnbJs7Op5ZX8cckiHGGTGHNb35kys/XukuCo + + + + urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + +`, entityID, validUntil, cacheDuration, entityID, entityID) + +} + +func (ts *SSOTestSuite) TestIsStaleSAMLMetadata() { + + // https://en.wikipedia.org/wiki/ISO_8601 + currentTime := time.Now() + currentTimeAsISO8601 := currentTime.UTC().Format("2006-01-02T15:04:05Z07:00") + examples := []struct { + Description string + Metadata []byte + IsStale bool + CacheDurationExceeded bool + }{ + { + Description: "Metadata is valid and within cache duration", + Metadata: []byte(configurableSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", dateInFarFuture, oneHour)), + IsStale: false, + CacheDurationExceeded: false, + }, + { + + Description: "Metadata is valid but is a minute past cache duration", + Metadata: []byte(configurableSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", currentTimeAsISO8601, oneHour)), + IsStale: true, + CacheDurationExceeded: true, + }, + + { + Description: "Metadata is invalid but within cache duration", + Metadata: []byte(configurableSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", dateInPast, oneHour)), + IsStale: true, + CacheDurationExceeded: false, + }, + } + + for _, example := range examples { + metadata, err := parseSAMLMetadata(example.Metadata) + require.NoError(ts.T(), err) + provider := models.SAMLProvider{ + EntityID: metadata.EntityID, + MetadataXML: string(example.Metadata), + UpdatedAt: currentTime, + } + if example.CacheDurationExceeded { + provider.UpdatedAt = currentTime.Add(-time.Minute * 59) + } + + require.Equal(ts.T(), example.IsStale, IsSAMLMetadataStale(metadata, provider)) + } + +} + +func validSAMLIDPMetadata(entityID string) string { + return fmt.Sprintf(` + + + + + MIIDdDCCAlygAwIBAgIGAYKSjRZiMA0GCSqGSIb3DQEBCwUAMHsxFDASBgNVBAoTC0dvb2dsZSBJ +bmMuMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MQ8wDQYDVQQDEwZHb29nbGUxGDAWBgNVBAsTD0dv +b2dsZSBGb3IgV29yazELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWEwHhcNMjIwODEy +MTQ1NDU1WhcNMjcwODExMTQ1NDU1WjB7MRQwEgYDVQQKEwtHb29nbGUgSW5jLjEWMBQGA1UEBxMN +TW91bnRhaW4gVmlldzEPMA0GA1UEAxMGR29vZ2xlMRgwFgYDVQQLEw9Hb29nbGUgRm9yIFdvcmsx +CzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAlncFzErcnZm7ZWO71NZStnCIAoYNKf6Uw3LPLzcvk0YrA/eBC3PVDHSfahi+apGO +Ytdq7IQUvBdto3rJTvP49fjyO0WLbAbiPC+dILt2Gx9kttxpSp99Bf+8ObL/fTy5Y2oHbJBfBX1V +qfDQIY0fcej3AndFYUOE0gZXyeSbnROB8W1PzHxOc7rq1mlas0rvyja7AK4gwXjIwyIGsFDmHnve +buqWOYMzOT9oD+iQq9BWYVHkXGZn0BXzKtnw9w8I3IxQdndUoCl95pYRIvdl1b0dWdO9cXtSsTkL +kAa8B/mCQcF4W2M3t/yKtrcLcRTALg3/Hc+Xz+3BpY/fSDk1SwIDAQABMA0GCSqGSIb3DQEBCwUA +A4IBAQCER02WLf6bKwTGVD/3VTntetIiETuPs46Dum8blbsg+2BYdAHIQcB9cLuMRosIw0nYj54m +SfiyfoWGcx3CkMup1MtKyWu+SqDHl9Bpf+GFLG0ngKD/zB6xwpv/TCi+g/FBYe2TvzD6B1V0z7Vs +Xf+Gc2TWBKmCuKf/g2AUt7IQLpOaqxuJVoZjp4sEMov6d3FnaoHQEd0lg+XmnYfLNtwe3QRSU0BD +x6lVV4kXi0x0n198/gkjnA85rPZoZ6dmqHtkcM0Gabgg6KEE5ubSDlWDsdv27uANceCZAoxd1+in +4/KqqkhynnbJs7Op5ZX8cckiHGGTGHNb35kys/XukuCo + + + + urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + +`, entityID, entityID, entityID) +} + +func (ts *SSOTestSuite) TestAdminCreateSSOProvider() { + examples := []struct { + StatusCode int + Request map[string]interface{} + }{ + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{}, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "oidc", + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-A"), + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B"), + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-DUPLICATE"), + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-ATTRIBUTE-MAPPING"), + "attribute_mapping": map[string]interface{}{ + "keys": map[string]interface{}{ + "username": map[string]interface{}{ + "name": "mail", + }, + }, + }, + }, + }, + { + StatusCode: http.StatusUnprocessableEntity, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-DUPLICATE"), + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-DOMAIN-A"), + "domains": []string{ + "example.com", + }, + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-DOMAIN-B"), + "domains": []string{ + "example.com", + }, + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + "metadata_url": "https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-METADATA-URL-TOO", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-METADATA-URL-TOO"), + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + "metadata_url": "http://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-METADATA-OVER-HTTP", + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + "metadata_url": "https://accounts.google.com\\o/saml2?idpid=EXAMPLE-WITH-INVALID-METADATA-URL", + }, + }, + // TODO: add example with metadata_url + } + + for i, example := range examples { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/admin/sso/providers", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), example.StatusCode, w.Code, "Example %d failed with body %q", i, response) + + if example.StatusCode != http.StatusCreated { + continue + } + + // now check if the provider can be queried (GET) + var provider struct { + ID string `json:"id"` + } + + require.NoError(ts.T(), json.Unmarshal(response, &provider)) + + req = httptest.NewRequest(http.MethodGet, "http://localhost/admin/sso/providers/"+provider.ID, nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err = io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + originalProviderID := provider.ID + provider.ID = "" + + require.NoError(ts.T(), json.Unmarshal(response, &provider)) + require.Equal(ts.T(), provider.ID, originalProviderID) + + // now check if the provider can be queried (List) + var providers struct { + Items []struct { + ID string `json:"id"` + } `json:"items"` + } + + req = httptest.NewRequest(http.MethodGet, "http://localhost/admin/sso/providers", nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err = io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + require.NoError(ts.T(), json.Unmarshal(response, &providers)) + + contained := false + for _, listProvider := range providers.Items { + if listProvider.ID == provider.ID { + contained = true + break + } + } + + require.True(ts.T(), contained) + } +} + +func (ts *SSOTestSuite) TestAdminUpdateSSOProvider() { + providers := []struct { + ID string + Request map[string]interface{} + }{ + { + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-A"), + }, + }, + { + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-C"), + "domains": []string{ + "example.com", + }, + }, + }, + } + + for i, example := range providers { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/admin/sso/providers", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + var payload struct { + ID string `json:"id"` + } + + require.NoError(ts.T(), json.Unmarshal(response, &payload)) + + providers[i].ID = payload.ID + } + + examples := []struct { + ID string + Status int + Request map[string]interface{} + }{ + { + ID: providers[0].ID, + Status: http.StatusBadRequest, // changing entity ID + Request: map[string]interface{}{ + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B"), + }, + }, + { + ID: providers[0].ID, + Status: http.StatusBadRequest, // domain already exists + Request: map[string]interface{}{ + "domains": []string{ + "example.com", + }, + }, + }, + { + ID: providers[1].ID, + Status: http.StatusOK, + Request: map[string]interface{}{ + "domains": []string{ + "example.com", + "example.org", + }, + }, + }, + { + ID: providers[1].ID, + Status: http.StatusOK, + Request: map[string]interface{}{ + "attribute_mapping": map[string]interface{}{ + "keys": map[string]interface{}{ + "username": map[string]interface{}{ + "name": "mail", + }, + }, + }, + }, + }, + } + + for _, example := range examples { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPut, "http://localhost/admin/sso/providers/"+example.ID, bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), w.Code, example.Status) + } +} + +func (ts *SSOTestSuite) TestAdminDeleteSSOProvider() { + providers := []struct { + ID string + Request map[string]interface{} + }{ + { + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-A"), + }, + }, + } + + for i, example := range providers { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/admin/sso/providers", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + var payload struct { + ID string `json:"id"` + } + + require.NoError(ts.T(), json.Unmarshal(response, &payload)) + + providers[i].ID = payload.ID + } + + examples := []struct { + ID string + Status int + }{ + { + ID: providers[0].ID, + Status: http.StatusOK, + }, + } + + for _, example := range examples { + req := httptest.NewRequest(http.MethodDelete, "http://localhost/admin/sso/providers/"+example.ID, nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), w.Code, example.Status) + } + + check := []struct { + ID string + }{ + { + ID: providers[0].ID, + }, + } + + for _, example := range check { + req := httptest.NewRequest(http.MethodGet, "http://localhost/admin/sso/providers/"+example.ID, nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) + } +} + +func (ts *SSOTestSuite) TestSingleSignOn() { + providers := []struct { + ID string + Request map[string]interface{} + }{ + { + // creates a SAML provider (EXAMPLE-A) + // does not have a domain mapping + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-A"), + }, + }, + { + // creates a SAML provider (EXAMPLE-B) + // does have a domain mapping on example.com + Request: map[string]interface{}{ + "type": "saml", + "domains": []string{ + "example.com", + }, + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B"), + }, + }, + } + + for i, example := range providers { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/admin/sso/providers", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + var payload struct { + ID string `json:"id"` + } + + require.NoError(ts.T(), json.Unmarshal(response, &payload)) + + providers[i].ID = payload.ID + } + + examples := []struct { + Code int + Request map[string]interface{} + URL string + }{ + { + // call /sso with provider_id (EXAMPLE-A) + // should be successful and redirect to the EXAMPLE-A SSO URL + Request: map[string]interface{}{ + "provider_id": providers[0].ID, + }, + Code: http.StatusSeeOther, + URL: "https://accounts.google.com/o/saml2?idpid=EXAMPLE-A", + }, + { + // call /sso with provider_id (EXAMPLE-A) and SSO PKCE + // should be successful and redirect to the EXAMPLE-A SSO URL + Request: map[string]interface{}{ + "provider_id": providers[0].ID, + "code_challenge": "vby3iMQ4XUuycKkEyNsYHXshPql1Dod7Ebey2iXTXm4", + "code_challenge_method": "s256", + }, + Code: http.StatusSeeOther, + URL: "https://accounts.google.com/o/saml2?idpid=EXAMPLE-A", + }, + { + // call /sso with domain=example.com (provider=EXAMPLE-B) + // should be successful and redirect to the EXAMPLE-B SSO URL + Request: map[string]interface{}{ + "domain": "example.com", + }, + Code: http.StatusSeeOther, + URL: "https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", + }, + { + // call /sso with domain=example.com (provider=EXAMPLE-B) + // should be successful and redirect to the EXAMPLE-B SSO URL + Request: map[string]interface{}{ + "domain": "example.com", + "skip_http_redirect": true, + }, + Code: http.StatusOK, + URL: "https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", + }, + { + // call /sso with domain=example.org (no such provider) + // should be unsuccessful with 404 + Request: map[string]interface{}{ + "domain": "example.org", + }, + Code: http.StatusNotFound, + }, + { + // call /sso with a provider_id= (no such provider) + // should be unsuccessful with 404 + Request: map[string]interface{}{ + "provider_id": "14d906bf-9bd5-4734-b7d1-3904e240610e", + }, + Code: http.StatusNotFound, + }, + } + + for _, example := range examples { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/sso", bytes.NewBuffer(body)) + // no authorization header intentional, this is a login endpoint + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), w.Code, example.Code) + + locationURLString := "" + + if example.Code == http.StatusSeeOther { + locationURLString = w.Header().Get("Location") + } else if example.Code == http.StatusOK { + var response struct { + URL string `json:"url"` + } + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + + require.NotEmpty(ts.T(), response.URL) + + locationURLString = response.URL + } else { + continue + } + + locationURL, err := url.ParseRequestURI(locationURLString) + require.NoError(ts.T(), err) + + locationQuery, err := url.ParseQuery(locationURL.RawQuery) + + require.NoError(ts.T(), err) + + samlQueryParams := []string{ + "SAMLRequest", + "RelayState", + "SigAlg", + "Signature", + } + + for _, param := range samlQueryParams { + require.True(ts.T(), locationQuery.Has(param)) + } + + for _, param := range samlQueryParams { + locationQuery.Del(param) + } + + locationURL.RawQuery = locationQuery.Encode() + + require.Equal(ts.T(), locationURL.String(), example.URL) + } +} + +func TestSSOCreateParamsValidation(t *testing.T) { + // TODO +} diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go new file mode 100644 index 000000000..fff2813c7 --- /dev/null +++ b/internal/api/ssoadmin.go @@ -0,0 +1,422 @@ +package api + +import ( + "context" + "io" + "net/http" + "net/url" + "strings" + "unicode/utf8" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" + "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +// loadSSOProvider looks for an idp_id parameter in the URL route and loads the SSO provider +// with that ID (or resource ID) and adds it to the context. +func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + + idpParam := chi.URLParam(r, "idp_id") + + idpID, err := uuid.FromString(idpParam) + if err != nil { + // idpParam is not UUIDv4 + return nil, notFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") + } + + // idpParam is a UUIDv4 + provider, err := models.FindSSOProviderByID(db, idpID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") + } else { + return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err) + } + } + + observability.LogEntrySetField(r, "sso_provider_id", provider.ID.String()) + + return withSSOProvider(r.Context(), provider), nil +} + +// adminSSOProvidersList lists all SAML SSO Identity Providers in the system. Does +// not deal with pagination at this time. +func (a *API) adminSSOProvidersList(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + providers, err := models.FindAllSAMLProviders(db) + if err != nil { + return err + } + + for i := range providers { + // remove metadata XML so that the returned JSON is not ginormous + providers[i].SAMLProvider.MetadataXML = "" + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "items": providers, + }) +} + +type CreateSSOProviderParams struct { + Type string `json:"type"` + + MetadataURL string `json:"metadata_url"` + MetadataXML string `json:"metadata_xml"` + Domains []string `json:"domains"` + AttributeMapping models.SAMLAttributeMapping `json:"attribute_mapping"` + NameIDFormat string `json:"name_id_format"` +} + +func (p *CreateSSOProviderParams) validate(forUpdate bool) error { + if !forUpdate && p.Type != "saml" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only 'saml' supported for SSO provider type") + } else if p.MetadataURL != "" && p.MetadataXML != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") + } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") + } else if p.MetadataURL != "" { + metadataURL, err := url.ParseRequestURI(p.MetadataURL) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "metadata_url is not a valid URL") + } + + if metadataURL.Scheme != "https" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") + } + } + + switch p.NameIDFormat { + case "", + string(saml.PersistentNameIDFormat), + string(saml.EmailAddressNameIDFormat), + string(saml.TransientNameIDFormat), + string(saml.UnspecifiedNameIDFormat): + // it's valid + + default: + return badRequestError(apierrors.ErrorCodeValidationFailed, "name_id_format must be unspecified or one of %v", strings.Join([]string{ + string(saml.PersistentNameIDFormat), + string(saml.EmailAddressNameIDFormat), + string(saml.TransientNameIDFormat), + string(saml.UnspecifiedNameIDFormat), + }, ", ")) + } + + return nil +} + +func (p *CreateSSOProviderParams) metadata(ctx context.Context) ([]byte, *saml.EntityDescriptor, error) { + var rawMetadata []byte + var err error + + if p.MetadataXML != "" { + rawMetadata = []byte(p.MetadataXML) + } else if p.MetadataURL != "" { + rawMetadata, err = fetchSAMLMetadata(ctx, p.MetadataURL) + if err != nil { + return nil, nil, err + } + } else { + // impossible situation if you called validate() prior + return nil, nil, nil + } + + metadata, err := parseSAMLMetadata(rawMetadata) + if err != nil { + return nil, nil, err + } + + return rawMetadata, metadata, nil +} + +func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { + if !utf8.Valid(rawMetadata) { + return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") + } + + metadata, err := samlsp.ParseMetadata(rawMetadata) + if err != nil { + return nil, err + } + + if metadata.EntityID == "" { + return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") + } + + if len(metadata.IDPSSODescriptors) < 1 { + return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") + } + + if len(metadata.IDPSSODescriptors) > 1 { + return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") + } + + return metadata, nil +} + +func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err) + } + + req = req.WithContext(ctx) + + req.Header.Set("Accept", "application/xml;charset=UTF-8") + req.Header.Set("Accept-Charset", "UTF-8") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + defer utilities.SafeClose(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, badRequestError(apierrors.ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return data, nil +} + +// adminSSOProvidersCreate creates a new SAML Identity Provider in the system. +func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + params := &CreateSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.validate(false /* <- forUpdate */); err != nil { + return err + } + + rawMetadata, metadata, err := params.metadata(ctx) + if err != nil { + return err + } + + existingProvider, err := models.FindSAMLProviderByEntityID(db, metadata.EntityID) + if err != nil && !models.IsNotFoundError(err) { + return err + } + if existingProvider != nil { + return unprocessableEntityError(apierrors.ErrorCodeSAMLIdPAlreadyExists, "SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) + } + + provider := &models.SSOProvider{ + // TODO handle Name, Description, Attribute Mapping + SAMLProvider: models.SAMLProvider{ + EntityID: metadata.EntityID, + MetadataXML: string(rawMetadata), + }, + } + + if params.MetadataURL != "" { + provider.SAMLProvider.MetadataURL = ¶ms.MetadataURL + } + + if params.NameIDFormat != "" { + provider.SAMLProvider.NameIDFormat = ¶ms.NameIDFormat + } + + provider.SAMLProvider.AttributeMapping = params.AttributeMapping + + for _, domain := range params.Domains { + existingProvider, err := models.FindSSOProviderByDomain(db, domain) + if err != nil && !models.IsNotFoundError(err) { + return err + } + if existingProvider != nil { + return badRequestError(apierrors.ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) + } + + provider.SSODomains = append(provider.SSODomains, models.SSODomain{ + Domain: domain, + }) + } + + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Eager().Create(provider); terr != nil { + return terr + } + + return tx.Eager().Load(provider) + }); err != nil { + return err + } + + return sendJSON(w, http.StatusCreated, provider) +} + +// adminSSOProvidersGet returns an existing SAML Identity Provider in the system. +func (a *API) adminSSOProvidersGet(w http.ResponseWriter, r *http.Request) error { + provider := getSSOProvider(r.Context()) + + return sendJSON(w, http.StatusOK, provider) +} + +// adminSSOProvidersUpdate updates a provider with the provided diff values. +func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + params := &CreateSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.validate(true /* <- forUpdate */); err != nil { + return err + } + + modified := false + updateSAMLProvider := false + + provider := getSSOProvider(ctx) + + if params.MetadataXML != "" || params.MetadataURL != "" { + // metadata is being updated + rawMetadata, metadata, err := params.metadata(ctx) + if err != nil { + return err + } + + if provider.SAMLProvider.EntityID != metadata.EntityID { + return badRequestError(apierrors.ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) + } + + if params.MetadataURL != "" { + provider.SAMLProvider.MetadataURL = ¶ms.MetadataURL + } + + provider.SAMLProvider.MetadataXML = string(rawMetadata) + updateSAMLProvider = true + modified = true + } + + // domains are being "updated" only when params.Domains is not nil, if + // it was nil (but not `[]`) then the caller is expecting not to modify + // the domains + updateDomains := params.Domains != nil + + var createDomains, deleteDomains []models.SSODomain + keepDomains := make(map[string]bool) + + for _, domain := range params.Domains { + existingProvider, err := models.FindSSOProviderByDomain(db, domain) + if err != nil && !models.IsNotFoundError(err) { + return err + } + if existingProvider != nil { + if existingProvider.ID == provider.ID { + keepDomains[domain] = true + } else { + return badRequestError(apierrors.ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) + } + } else { + modified = true + createDomains = append(createDomains, models.SSODomain{ + Domain: domain, + SSOProviderID: provider.ID, + }) + } + } + + if updateDomains { + for i, domain := range provider.SSODomains { + if !keepDomains[domain.Domain] { + modified = true + deleteDomains = append(deleteDomains, provider.SSODomains[i]) + } + } + } + + updateAttributeMapping := false + if params.AttributeMapping.Keys != nil { + updateAttributeMapping = !provider.SAMLProvider.AttributeMapping.Equal(¶ms.AttributeMapping) + if updateAttributeMapping { + modified = true + provider.SAMLProvider.AttributeMapping = params.AttributeMapping + } + } + + nameIDFormat := "" + if provider.SAMLProvider.NameIDFormat != nil { + nameIDFormat = *provider.SAMLProvider.NameIDFormat + } + + if params.NameIDFormat != nameIDFormat { + modified = true + + if params.NameIDFormat == "" { + provider.SAMLProvider.NameIDFormat = nil + } else { + provider.SAMLProvider.NameIDFormat = ¶ms.NameIDFormat + } + } + + if modified { + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Eager().Update(provider); terr != nil { + return terr + } + + if updateDomains { + if terr := tx.Destroy(deleteDomains); terr != nil { + return terr + } + + if terr := tx.Eager().Create(createDomains); terr != nil { + return terr + } + } + + if updateAttributeMapping || updateSAMLProvider { + if terr := tx.Eager().Update(&provider.SAMLProvider); terr != nil { + return terr + } + } + + return tx.Eager().Load(provider) + }); err != nil { + return unprocessableEntityError(apierrors.ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) + } + } + + return sendJSON(w, http.StatusOK, provider) +} + +// adminSSOProvidersDelete deletes a SAML identity provider. +func (a *API) adminSSOProvidersDelete(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + provider := getSSOProvider(ctx) + + if err := db.Transaction(func(tx *storage.Connection) error { + return tx.Eager().Destroy(provider) + }); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, provider) +} diff --git a/internal/api/token.go b/internal/api/token.go new file mode 100644 index 000000000..098d90640 --- /dev/null +++ b/internal/api/token.go @@ -0,0 +1,509 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strconv" + "time" + + "github.com/gofrs/uuid" + "github.com/golang-jwt/jwt/v5" + "github.com/xeipuuv/gojsonschema" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +// AccessTokenClaims is a struct thats used for JWT claims +type AccessTokenClaims struct { + jwt.RegisteredClaims + Email string `json:"email"` + Phone string `json:"phone"` + AppMetaData map[string]interface{} `json:"app_metadata"` + UserMetaData map[string]interface{} `json:"user_metadata"` + Role string `json:"role"` + AuthenticatorAssuranceLevel string `json:"aal,omitempty"` + AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"` + SessionId string `json:"session_id,omitempty"` + IsAnonymous bool `json:"is_anonymous"` +} + +// AccessTokenResponse represents an OAuth2 success response +type AccessTokenResponse struct { + Token string `json:"access_token"` + TokenType string `json:"token_type"` // Bearer + ExpiresIn int `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + RefreshToken string `json:"refresh_token"` + User *models.User `json:"user"` + ProviderAccessToken string `json:"provider_token,omitempty"` + ProviderRefreshToken string `json:"provider_refresh_token,omitempty"` + WeakPassword *WeakPasswordError `json:"weak_password,omitempty"` +} + +// AsRedirectURL encodes the AccessTokenResponse as a redirect URL that +// includes the access token response data in a URL fragment. +func (r *AccessTokenResponse) AsRedirectURL(redirectURL string, extraParams url.Values) string { + extraParams.Set("access_token", r.Token) + extraParams.Set("token_type", r.TokenType) + extraParams.Set("expires_in", strconv.Itoa(r.ExpiresIn)) + extraParams.Set("expires_at", strconv.FormatInt(r.ExpiresAt, 10)) + extraParams.Set("refresh_token", r.RefreshToken) + + return redirectURL + "#" + extraParams.Encode() +} + +// PasswordGrantParams are the parameters the ResourceOwnerPasswordGrant method accepts +type PasswordGrantParams struct { + Email string `json:"email"` + Phone string `json:"phone"` + Password string `json:"password"` +} + +// PKCEGrantParams are the parameters the PKCEGrant method accepts +type PKCEGrantParams struct { + AuthCode string `json:"auth_code"` + CodeVerifier string `json:"code_verifier"` +} + +const useCookieHeader = "x-use-cookie" +const InvalidLoginMessage = "Invalid login credentials" + +// Token is the endpoint for OAuth access token requests +func (a *API) Token(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + grantType := r.FormValue("grant_type") + + switch grantType { + case "password": + return a.ResourceOwnerPasswordGrant(ctx, w, r) + case "refresh_token": + return a.RefreshTokenGrant(ctx, w, r) + case "id_token": + return a.IdTokenGrant(ctx, w, r) + case "pkce": + return a.PKCE(ctx, w, r) + case "web3": + return a.Web3Grant(ctx, w, r) + default: + return badRequestError(apierrors.ErrorCodeInvalidCredentials, "unsupported_grant_type") + } +} + +// ResourceOwnerPasswordGrant implements the password grant type flow +func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + db := a.db.WithContext(ctx) + + params := &PasswordGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + aud := a.requestAud(ctx, r) + config := a.config + + if params.Email != "" && params.Phone != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.") + } + var user *models.User + var grantParams models.GrantParams + var provider string + var err error + + grantParams.FillGrantParams(r) + + if params.Email != "" { + provider = "email" + if !config.External.Email.Enabled { + return unprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") + } + user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) + } else if params.Phone != "" { + provider = "phone" + if !config.External.Phone.Enabled { + return unprocessableEntityError(apierrors.ErrorCodePhoneProviderDisabled, "Phone logins are disabled") + } + params.Phone = formatPhoneNumber(params.Phone) + user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) + } else { + return badRequestError(apierrors.ErrorCodeValidationFailed, "missing email or phone") + } + + if err != nil { + if models.IsNotFoundError(err) { + return badRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) + } + return internalServerError("Database error querying schema").WithInternalError(err) + } + + if !user.HasPassword() { + return badRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) + } + + if user.IsBanned() { + return badRequestError(apierrors.ErrorCodeUserBanned, "User is banned") + } + + isValidPassword, shouldReEncrypt, err := user.Authenticate(ctx, db, params.Password, config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + if err != nil { + return err + } + + var weakPasswordError *WeakPasswordError + if isValidPassword { + if err := a.checkPasswordStrength(ctx, params.Password); err != nil { + if wpe, ok := err.(*WeakPasswordError); ok { + weakPasswordError = wpe + } else { + observability.GetLogEntry(r).Entry.WithError(err).Warn("Password strength check on sign-in failed") + } + } + + if shouldReEncrypt { + if err := user.SetPassword(ctx, params.Password, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + + // directly change this in the database without + // calling user.UpdatePassword() because this + // is not a password change, just encryption + // change in the database + if err := db.UpdateOnly(user, "encrypted_password"); err != nil { + return err + } + } + } + + if config.Hook.PasswordVerificationAttempt.Enabled { + input := hooks.PasswordVerificationAttemptInput{ + UserID: user.ID, + Valid: isValidPassword, + } + output := hooks.PasswordVerificationAttemptOutput{} + if err := a.invokeHook(nil, r, &input, &output); err != nil { + return err + } + + if output.Decision == hooks.HookRejection { + if output.Message == "" { + output.Message = hooks.DefaultPasswordHookRejectionMessage + } + if output.ShouldLogoutUser { + if err := models.Logout(a.db, user.ID); err != nil { + return err + } + } + return badRequestError(apierrors.ErrorCodeInvalidCredentials, output.Message) + } + } + if !isValidPassword { + return badRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) + } + + if params.Email != "" && !user.IsConfirmed() && !a.config.Mailer.AllowUnverifiedEmailSignIns { + return badRequestError(apierrors.ErrorCodeEmailNotConfirmed, "Email not confirmed") + } else if params.Phone != "" && !user.IsPhoneConfirmed() { + return badRequestError(apierrors.ErrorCodePhoneNotConfirmed, "Phone not confirmed") + } + + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider": provider, + }); terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams) + if terr != nil { + return terr + } + + return nil + }) + if err != nil { + return err + } + + token.WeakPassword = weakPasswordError + + metering.RecordLogin("password", user.ID) + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + db := a.db.WithContext(ctx) + var grantParams models.GrantParams + + // There is a slight problem with this as it will pick-up the + // User-Agent and IP addresses from the server if used on the server + // side. Currently there's no mechanism to distinguish, but the server + // can be told to at least propagate the User-Agent header. + grantParams.FillGrantParams(r) + + params := &PKCEGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if params.AuthCode == "" || params.CodeVerifier == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty") + } + + flowState, err := models.FindFlowStateByAuthCode(db, params.AuthCode) + // Sanity check in case user ID was not set properly + if models.IsNotFoundError(err) || flowState.UserID == nil { + return notFoundError(apierrors.ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found") + } else if err != nil { + return err + } + if flowState.IsExpired(a.config.External.FlowStateExpiryDuration) { + return unprocessableEntityError(apierrors.ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired") + } + + user, err := models.FindUserByID(db, *flowState.UserID) + if err != nil { + return err + } + if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil { + return badRequestError(apierrors.ErrorCodeBadCodeVerifier, err.Error()) + } + + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + authMethod, err := models.ParseAuthenticationMethod(flowState.AuthenticationMethod) + if err != nil { + return err + } + if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider_type": flowState.ProviderType, + }); terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, user, authMethod, grantParams) + if terr != nil { + // error type is already handled in issueRefreshToken + return terr + } + token.ProviderAccessToken = flowState.ProviderAccessToken + // Because not all providers give out a refresh token + // See corresponding OAuth2 spec: + if flowState.ProviderRefreshToken != "" { + token.ProviderRefreshToken = flowState.ProviderRefreshToken + } + if terr = tx.Destroy(flowState); terr != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) { + config := a.config + if sessionId == nil { + return "", 0, internalServerError("Session is required to issue access token") + } + sid := sessionId.String() + session, terr := models.FindSessionByID(tx, *sessionId, false) + if terr != nil { + return "", 0, terr + } + aal, amr, terr := session.CalculateAALAndAMR(user) + if terr != nil { + return "", 0, terr + } + + issuedAt := time.Now().UTC() + expiresAt := issuedAt.Add(time.Second * time.Duration(config.JWT.Exp)) + + claims := &hooks.AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: user.ID.String(), + Audience: jwt.ClaimStrings{user.Aud}, + IssuedAt: jwt.NewNumericDate(issuedAt), + ExpiresAt: jwt.NewNumericDate(expiresAt), + Issuer: config.JWT.Issuer, + }, + Email: user.GetEmail(), + Phone: user.GetPhone(), + AppMetaData: user.AppMetaData, + UserMetaData: user.UserMetaData, + Role: user.Role, + SessionId: sid, + AuthenticatorAssuranceLevel: aal.String(), + AuthenticationMethodReference: amr, + IsAnonymous: user.IsAnonymous, + } + + var gotrueClaims jwt.Claims = claims + if config.Hook.CustomAccessToken.Enabled { + input := hooks.CustomAccessTokenInput{ + UserID: user.ID, + Claims: claims, + AuthenticationMethod: authenticationMethod.String(), + } + + output := hooks.CustomAccessTokenOutput{} + + err := a.invokeHook(tx, r, &input, &output) + if err != nil { + return "", 0, err + } + gotrueClaims = jwt.MapClaims(output.Claims) + } + + signed, err := signJwt(&config.JWT, gotrueClaims) + if err != nil { + return "", 0, err + } + + return signed, expiresAt.Unix(), nil +} + +func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { + config := a.config + + now := time.Now() + user.LastSignInAt = &now + + var tokenString string + var expiresAt int64 + var refreshToken *models.RefreshToken + + err := conn.Transaction(func(tx *storage.Connection) error { + var terr error + + refreshToken, terr = models.GrantAuthenticatedUser(tx, user, grantParams) + if terr != nil { + return internalServerError("Database error granting user").WithInternalError(terr) + } + + terr = models.AddClaimToSession(tx, *refreshToken.SessionId, authenticationMethod) + if terr != nil { + return terr + } + + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, refreshToken.SessionId, authenticationMethod) + if terr != nil { + // Account for Hook Error + httpErr, ok := terr.(*HTTPError) + if ok { + return httpErr + } + return internalServerError("error generating jwt token").WithInternalError(terr) + } + return nil + }) + if err != nil { + return nil, err + } + + return &AccessTokenResponse{ + Token: tokenString, + TokenType: "bearer", + ExpiresIn: config.JWT.Exp, + ExpiresAt: expiresAt, + RefreshToken: refreshToken.Token, + User: user, + }, nil +} + +func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { + ctx := r.Context() + config := a.config + var tokenString string + var expiresAt int64 + var refreshToken *models.RefreshToken + currentClaims := getClaims(ctx) + sessionId, err := uuid.FromString(currentClaims.SessionId) + if err != nil { + return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err) + } + + err = tx.Transaction(func(tx *storage.Connection) error { + if terr := models.AddClaimToSession(tx, sessionId, authenticationMethod); terr != nil { + return terr + } + session, terr := models.FindSessionByID(tx, sessionId, false) + if terr != nil { + return terr + } + currentToken, terr := models.FindTokenBySessionID(tx, &session.ID) + if terr != nil { + return terr + } + if err := tx.Load(user, "Identities"); err != nil { + return err + } + // Swap to ensure current token is the latest one + refreshToken, terr = models.GrantRefreshTokenSwap(r, tx, user, currentToken) + if terr != nil { + return terr + } + aal, _, terr := session.CalculateAALAndAMR(user) + if terr != nil { + return terr + } + + if err := session.UpdateAALAndAssociatedFactor(tx, aal, grantParams.FactorID); err != nil { + return err + } + + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, &session.ID, authenticationMethod) + if terr != nil { + httpErr, ok := terr.(*HTTPError) + if ok { + return httpErr + } + return internalServerError("error generating jwt token").WithInternalError(terr) + } + return nil + }) + if err != nil { + return nil, err + } + return &AccessTokenResponse{ + Token: tokenString, + TokenType: "bearer", + ExpiresIn: config.JWT.Exp, + ExpiresAt: expiresAt, + RefreshToken: refreshToken.Token, + User: user, + }, nil +} + +func validateTokenClaims(outputClaims map[string]interface{}) error { + schemaLoader := gojsonschema.NewStringLoader(hooks.MinimumViableTokenSchema) + + documentLoader := gojsonschema.NewGoLoader(outputClaims) + + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + return err + } + + if !result.Valid() { + var errorMessages string + + for _, desc := range result.Errors() { + errorMessages += fmt.Sprintf("- %s\n", desc) + fmt.Printf("- %s\n", desc) + } + return fmt.Errorf("output claims do not conform to the expected schema: \n%s", errorMessages) + + } + + return nil +} diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go new file mode 100644 index 000000000..650d1d96d --- /dev/null +++ b/internal/api/token_oidc.go @@ -0,0 +1,254 @@ +package api + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +// IdTokenGrantParams are the parameters the IdTokenGrant method accepts +type IdTokenGrantParams struct { + IdToken string `json:"id_token"` + AccessToken string `json:"access_token"` + Nonce string `json:"nonce"` + Provider string `json:"provider"` + ClientID string `json:"client_id"` + Issuer string `json:"issuer"` +} + +func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, bool, string, []string, error) { + log := observability.GetLogEntry(r).Entry + + var cfg *conf.OAuthProviderConfiguration + var issuer string + var providerType string + var acceptableClientIDs []string + + switch true { + case p.Provider == "apple" || p.Issuer == provider.IssuerApple: + cfg = &config.External.Apple + providerType = "apple" + issuer = provider.IssuerApple + acceptableClientIDs = append(acceptableClientIDs, config.External.Apple.ClientID...) + + if config.External.IosBundleId != "" { + acceptableClientIDs = append(acceptableClientIDs, config.External.IosBundleId) + } + + case p.Provider == "google" || p.Issuer == provider.IssuerGoogle: + cfg = &config.External.Google + providerType = "google" + issuer = provider.IssuerGoogle + acceptableClientIDs = append(acceptableClientIDs, config.External.Google.ClientID...) + + case p.Provider == "azure" || provider.IsAzureIssuer(p.Issuer): + issuer = p.Issuer + if issuer == "" || !provider.IsAzureIssuer(issuer) { + detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken) + if err != nil { + return nil, false, "", nil, badRequestError(apierrors.ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) + } + issuer = detectedIssuer + } + cfg = &config.External.Azure + providerType = "azure" + acceptableClientIDs = append(acceptableClientIDs, config.External.Azure.ClientID...) + + case p.Provider == "facebook" || p.Issuer == provider.IssuerFacebook: + cfg = &config.External.Facebook + providerType = "facebook" + issuer = provider.IssuerFacebook + acceptableClientIDs = append(acceptableClientIDs, config.External.Facebook.ClientID...) + + case p.Provider == "keycloak" || (config.External.Keycloak.Enabled && config.External.Keycloak.URL != "" && p.Issuer == config.External.Keycloak.URL): + cfg = &config.External.Keycloak + providerType = "keycloak" + issuer = config.External.Keycloak.URL + acceptableClientIDs = append(acceptableClientIDs, config.External.Keycloak.ClientID...) + + case p.Provider == "kakao" || p.Issuer == provider.IssuerKakao: + cfg = &config.External.Kakao + providerType = "kakao" + issuer = provider.IssuerKakao + acceptableClientIDs = append(acceptableClientIDs, config.External.Kakao.ClientID...) + + case p.Provider == "vercel_marketplace" || p.Issuer == provider.IssuerVercelMarketplace: + cfg = &config.External.VercelMarketplace + providerType = "vercel_marketplace" + issuer = provider.IssuerVercelMarketplace + acceptableClientIDs = append(acceptableClientIDs, config.External.VercelMarketplace.ClientID...) + + default: + log.WithField("issuer", p.Issuer).WithField("client_id", p.ClientID).Warn("Use of POST /token with arbitrary issuer and client_id is deprecated for security reasons. Please switch to using the API with provider only!") + + allowed := false + for _, allowedIssuer := range config.External.AllowedIdTokenIssuers { + if p.Issuer == allowedIssuer { + allowed = true + providerType = allowedIssuer + acceptableClientIDs = []string{p.ClientID} + issuer = allowedIssuer + break + } + } + + if !allowed { + return nil, false, "", nil, badRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + } + + cfg = &conf.OAuthProviderConfiguration{ + Enabled: true, + SkipNonceCheck: false, + } + } + + if !cfg.Enabled { + return nil, false, "", nil, badRequestError(apierrors.ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) + } + + oidcProvider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + return nil, false, "", nil, err + } + + return oidcProvider, cfg.SkipNonceCheck, providerType, acceptableClientIDs, nil +} + +// IdTokenGrant implements the id_token grant type flow +func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + log := observability.GetLogEntry(r).Entry + + db := a.db.WithContext(ctx) + config := a.config + + params := &IdTokenGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if params.IdToken == "" { + return oauthError("invalid request", "id_token required") + } + + if params.Provider == "" && (params.ClientID == "" || params.Issuer == "") { + return oauthError("invalid request", "provider or client_id and issuer required") + } + + oidcProvider, skipNonceCheck, providerType, acceptableClientIDs, err := params.getProvider(ctx, config, r) + if err != nil { + return err + } + + idToken, userData, err := provider.ParseIDToken(ctx, oidcProvider, nil, params.IdToken, provider.ParseIDTokenOptions{ + SkipAccessTokenCheck: params.AccessToken == "", + AccessToken: params.AccessToken, + }) + if err != nil { + return oauthError("invalid request", "Bad ID token").WithInternalError(err) + } + + userData.Metadata.EmailVerified = false + for _, email := range userData.Emails { + if email.Primary { + userData.Metadata.Email = email.Email + userData.Metadata.EmailVerified = email.Verified + break + } else { + userData.Metadata.Email = email.Email + userData.Metadata.EmailVerified = email.Verified + } + } + + if idToken.Subject == "" { + return oauthError("invalid request", "Missing sub claim in id_token") + } + + correctAudience := false + for _, clientID := range acceptableClientIDs { + if clientID == "" { + continue + } + + for _, aud := range idToken.Audience { + if aud == clientID { + correctAudience = true + break + } + } + + if correctAudience { + break + } + } + + if !correctAudience { + return oauthError("invalid request", fmt.Sprintf("Unacceptable audience in id_token: %v", idToken.Audience)) + } + + if !skipNonceCheck { + tokenHasNonce := idToken.Nonce != "" + paramsHasNonce := params.Nonce != "" + + if tokenHasNonce != paramsHasNonce { + return oauthError("invalid request", "Passed nonce and nonce in id_token should either both exist or not.") + } else if tokenHasNonce && paramsHasNonce { + // verify nonce to mitigate replay attacks + hash := fmt.Sprintf("%x", sha256.Sum256([]byte(params.Nonce))) + if hash != idToken.Nonce { + return oauthError("invalid nonce", "Nonces mismatch") + } + } + } + + if params.AccessToken == "" { + if idToken.AccessTokenHash != "" { + log.Warn("ID token has a at_hash claim, but no access_token parameter was provided. In future versions, access_token will be mandatory as it's security best practice.") + } + } else { + if idToken.AccessTokenHash == "" { + log.Info("ID token does not have a at_hash claim, access_token parameter is unused.") + } + } + + var token *AccessTokenResponse + var grantParams models.GrantParams + + grantParams.FillGrantParams(r) + + if err := db.Transaction(func(tx *storage.Connection) error { + var user *models.User + var terr error + + user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType) + if terr != nil { + return terr + } + + token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams) + if terr != nil { + return terr + } + + return nil + }); err != nil { + switch err.(type) { + case *storage.CommitWithError: + return err + case *HTTPError: + return err + default: + return oauthError("server_error", "Internal Server Error").WithInternalError(err) + } + } + + return sendJSON(w, http.StatusOK, token) +} diff --git a/internal/api/token_oidc_test.go b/internal/api/token_oidc_test.go new file mode 100644 index 000000000..1eab99ebd --- /dev/null +++ b/internal/api/token_oidc_test.go @@ -0,0 +1,69 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" +) + +type TokenOIDCTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestTokenOIDC(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &TokenOIDCTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func SetupTestOIDCProvider(ts *TokenOIDCTestSuite) *httptest.Server { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"issuer":"` + server.URL + `","authorization_endpoint":"` + server.URL + `/authorize","token_endpoint":"` + server.URL + `/token","jwks_uri":"` + server.URL + `/jwks"}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + return server +} + +func (ts *TokenOIDCTestSuite) TestGetProvider() { + server := SetupTestOIDCProvider(ts) + defer server.Close() + + params := &IdTokenGrantParams{ + IdToken: "test-id-token", + AccessToken: "test-access-token", + Nonce: "test-nonce", + Provider: server.URL, + ClientID: "test-client-id", + Issuer: server.URL, + } + + ts.Config.External.AllowedIdTokenIssuers = []string{server.URL} + + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + oidcProvider, skipNonceCheck, providerType, acceptableClientIds, err := params.getProvider(context.Background(), ts.Config, req) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), oidcProvider) + require.False(ts.T(), skipNonceCheck) + require.Equal(ts.T(), params.Provider, providerType) + require.NotEmpty(ts.T(), acceptableClientIds) +} diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go new file mode 100644 index 000000000..7cb5c66dc --- /dev/null +++ b/internal/api/token_refresh.go @@ -0,0 +1,275 @@ +package api + +import ( + "context" + mathRand "math/rand" + "net/http" + "time" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +const retryLoopDuration = 5.0 + +// RefreshTokenGrantParams are the parameters the RefreshTokenGrant method accepts +type RefreshTokenGrantParams struct { + RefreshToken string `json:"refresh_token"` +} + +// RefreshTokenGrant implements the refresh_token grant type flow +func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + db := a.db.WithContext(ctx) + config := a.config + + params := &RefreshTokenGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if params.RefreshToken == "" { + return oauthError("invalid_request", "refresh_token required") + } + + // A 5 second retry loop is used to make sure that refresh token + // requests do not waste database connections waiting for each other. + // Instead of waiting at the database level, they're waiting at the API + // level instead and retry to refresh the locked row every 10-30 + // milliseconds. + retryStart := a.Now() + retry := true + + for retry && time.Since(retryStart).Seconds() < retryLoopDuration { + retry = false + + user, token, session, err := models.FindUserWithRefreshToken(db, params.RefreshToken, false) + if err != nil { + if models.IsNotFoundError(err) { + return badRequestError(apierrors.ErrorCodeRefreshTokenNotFound, "Invalid Refresh Token: Refresh Token Not Found") + } + return internalServerError(err.Error()) + } + + if user.IsBanned() { + return badRequestError(apierrors.ErrorCodeUserBanned, "Invalid Refresh Token: User Banned") + } + + if session == nil { + // a refresh token won't have a session if it's created prior to the sessions table introduced + if err := db.Destroy(token); err != nil { + return internalServerError("Error deleting refresh token with missing session").WithInternalError(err) + } + return badRequestError(apierrors.ErrorCodeSessionNotFound, "Invalid Refresh Token: No Valid Session Found") + } + + result := session.CheckValidity(retryStart, &token.UpdatedAt, config.Sessions.Timebox, config.Sessions.InactivityTimeout) + + switch result { + case models.SessionValid: + // do nothing + + case models.SessionTimedOut: + return badRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Inactivity)") + + default: + return badRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired") + } + + // Basic checks above passed, now we need to serialize access + // to the session in a transaction so that there's no + // concurrent modification. In the event that the refresh + // token's row or session is locked, the transaction is closed + // and the whole process will be retried a bit later so that + // the connection pool does not get exhausted. + + var tokenString string + var expiresAt int64 + var newTokenResponse *AccessTokenResponse + + err = db.Transaction(func(tx *storage.Connection) error { + user, token, session, terr := models.FindUserWithRefreshToken(tx, params.RefreshToken, true /* forUpdate */) + if terr != nil { + if models.IsNotFoundError(terr) { + // because forUpdate was set, and the + // previous check outside the + // transaction found a refresh token + // and session, but now we're getting a + // IsNotFoundError, this means that the + // refresh token row and session are + // probably locked so we need to retry + // in a few milliseconds. + retry = true + return terr + } + return internalServerError(terr.Error()) + } + + if a.config.Sessions.SinglePerUser { + sessions, terr := models.FindAllSessionsForUser(tx, user.ID, true /* forUpdate */) + if models.IsNotFoundError(terr) { + // because forUpdate was set, and the + // previous check outside the + // transaction found a user and + // session, but now we're getting a + // IsNotFoundError, this means that the + // user is locked and we need to retry + // in a few milliseconds + retry = true + return terr + } else if terr != nil { + return internalServerError(terr.Error()) + } + + sessionTag := session.DetermineTag(config.Sessions.Tags) + + // go through all sessions of the user and + // check if the current session is the user's + // most recently refreshed valid session + for _, s := range sessions { + if s.ID == session.ID { + // current session, skip it + continue + } + + if s.CheckValidity(retryStart, nil, config.Sessions.Timebox, config.Sessions.InactivityTimeout) != models.SessionValid { + // session is not valid so it + // can't be regarded as active + // on the user + continue + } + + if s.DetermineTag(config.Sessions.Tags) != sessionTag { + // if tags are specified, + // ignore sessions with a + // mismatching tag + continue + } + + // since token is not the refresh token + // of s, we can't use it's UpdatedAt + // time to compare! + if s.LastRefreshedAt(nil).After(session.LastRefreshedAt(&token.UpdatedAt)) { + // session is not the most + // recently active one + return badRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Revoked by Newer Login)") + } + } + + // this session is the user's active session + } + + // refresh token row and session are locked at this + // point, cannot be concurrently refreshed + + var issuedToken *models.RefreshToken + + if token.Revoked { + activeRefreshToken, terr := session.FindCurrentlyActiveRefreshToken(tx) + if terr != nil && !models.IsNotFoundError(terr) { + return internalServerError(terr.Error()) + } + + if activeRefreshToken != nil && activeRefreshToken.Parent.String() == token.Token { + // Token was revoked, but it's the + // parent of the currently active one. + // This indicates that the client was + // not able to store the result when it + // refreshed token. This case is + // allowed, provided we return back the + // active refresh token instead of + // creating a new one. + issuedToken = activeRefreshToken + } else { + // For a revoked refresh token to be reused, it + // has to fall within the reuse interval. + reuseUntil := token.UpdatedAt.Add( + time.Second * time.Duration(config.Security.RefreshTokenReuseInterval)) + + if a.Now().After(reuseUntil) { + // not OK to reuse this token + if config.Security.RefreshTokenRotationEnabled { + // Revoke all tokens in token family + if err := models.RevokeTokenFamily(tx, token); err != nil { + return internalServerError(err.Error()) + } + } + + return storage.NewCommitWithError(badRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)) + } + } + } + + if terr = models.NewAuditLogEntry(r, tx, user, models.TokenRefreshedAction, "", nil); terr != nil { + return terr + } + + if issuedToken == nil { + newToken, terr := models.GrantRefreshTokenSwap(r, tx, user, token) + if terr != nil { + return terr + } + + issuedToken = newToken + } + + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, issuedToken.SessionId, models.TokenRefresh) + if terr != nil { + httpErr, ok := terr.(*HTTPError) + if ok { + return httpErr + } + return internalServerError("error generating jwt token").WithInternalError(terr) + } + + refreshedAt := a.Now() + session.RefreshedAt = &refreshedAt + + userAgent := r.Header.Get("User-Agent") + if userAgent != "" { + session.UserAgent = &userAgent + } else { + session.UserAgent = nil + } + + ipAddress := utilities.GetIPAddress(r) + if ipAddress != "" { + session.IP = &ipAddress + } else { + session.IP = nil + } + + if terr := session.UpdateOnlyRefreshInfo(tx); terr != nil { + return internalServerError("failed to update session information").WithInternalError(terr) + } + + newTokenResponse = &AccessTokenResponse{ + Token: tokenString, + TokenType: "bearer", + ExpiresIn: config.JWT.Exp, + ExpiresAt: expiresAt, + RefreshToken: issuedToken.Token, + User: user, + } + + return nil + }) + if err != nil { + if retry && models.IsNotFoundError(err) { + // refresh token and session row were likely locked, so + // we need to wait a moment before retrying the whole + // process anew + time.Sleep(time.Duration(10+mathRand.Intn(20)) * time.Millisecond) // #nosec + continue + } else { + return err + } + } + metering.RecordLogin("token", user.ID) + return sendJSON(w, http.StatusOK, newTokenResponse) + } + + return conflictError("Too many concurrent token refresh requests on the same session or refresh token") +} diff --git a/internal/api/token_test.go b/internal/api/token_test.go new file mode 100644 index 000000000..540e28272 --- /dev/null +++ b/internal/api/token_test.go @@ -0,0 +1,858 @@ +package api + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type TokenTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + + RefreshToken *models.RefreshToken + User *models.User +} + +func TestToken(t *testing.T) { + os.Setenv("GOTRUE_RATE_LIMIT_HEADER", "My-Custom-Header") + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &TokenTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *TokenTestSuite) SetupTest() { + ts.RefreshToken = nil + models.TruncateAll(ts.API.db) + + // Create user & refresh token + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + t := time.Now() + u.EmailConfirmedAt = &t + u.BannedUntil = nil + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + + ts.User = u + ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) + require.NoError(ts.T(), err, "Error creating refresh token") + ts.Config.Hook.CustomAccessToken.Enabled = false + +} + +func (ts *TokenTestSuite) TestSessionTimebox() { + timebox := 10 * time.Second + + ts.API.config.Sessions.Timebox = &timebox + ts.API.overrideTime = func() time.Time { + return time.Now().Add(timebox).Add(time.Second) + } + + defer func() { + ts.API.overrideTime = nil + ts.API.config.Sessions.Timebox = nil + }() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), apierrors.ErrorCodeSessionExpired, firstResult.ErrorCode) + assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired", firstResult.Message) +} + +func (ts *TokenTestSuite) TestSessionInactivityTimeout() { + inactivityTimeout := 10 * time.Second + + ts.API.config.Sessions.InactivityTimeout = &inactivityTimeout + ts.API.overrideTime = func() time.Time { + return time.Now().Add(inactivityTimeout).Add(time.Second) + } + + defer func() { + ts.API.config.Sessions.InactivityTimeout = nil + ts.API.overrideTime = nil + }() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), apierrors.ErrorCodeSessionExpired, firstResult.ErrorCode) + assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Inactivity)", firstResult.Message) +} + +func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() { + var buffer bytes.Buffer + + // first refresh + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var firstResult struct { + RefreshToken string `json:"refresh_token"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.NotEmpty(ts.T(), firstResult.RefreshToken) + + // pretend that the browser wasn't able to save the firstResult, + // run again with the first refresh token + buffer = bytes.Buffer{} + + // second refresh with the reused refresh token + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var secondResult struct { + RefreshToken string `json:"refresh_token"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&secondResult)) + assert.NotEmpty(ts.T(), secondResult.RefreshToken) + + // new refresh token is not being issued but the active one from + // the first refresh that failed to save is stored + assert.Equal(ts.T(), firstResult.RefreshToken, secondResult.RefreshToken) +} + +func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() { + ts.API.config.Sessions.SinglePerUser = true + defer func() { + ts.API.config.Sessions.SinglePerUser = false + }() + + firstRefreshToken := ts.RefreshToken + + // just in case to give some delay between first and second session creation + time.Sleep(10 * time.Millisecond) + + secondRefreshToken, err := models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{}) + + require.NoError(ts.T(), err) + + require.NotEqual(ts.T(), *firstRefreshToken.SessionId, *secondRefreshToken.SessionId) + require.Equal(ts.T(), firstRefreshToken.UserID, secondRefreshToken.UserID) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": firstRefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + assert.True(ts.T(), ts.API.config.Sessions.SinglePerUser) + + var firstResult struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), apierrors.ErrorCodeSessionExpired, firstResult.ErrorCode) + assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Revoked by Newer Login)", firstResult.Message) +} + +func (ts *TokenTestSuite) TestRateLimitTokenRefresh() { + var buffer bytes.Buffer + req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "1.2.3.4") + + // It rate limits after 30 requests + for i := 0; i < 30; i++ { + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + } + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It ignores X-Forwarded-For by default + req.Header.Set("X-Forwarded-For", "1.1.1.1") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It doesn't rate limit a new value for the limited header + req = httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "5.6.7.8") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) TestTokenPasswordGrantSuccess() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "password", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *TokenTestSuite) TestTokenRefreshTokenGrantSuccess() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *TokenTestSuite) TestTokenPasswordGrantFailure() { + u := ts.createBannedUser() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": u.GetEmail(), + "password": "password", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { + authCode := "1234563" + codeVerifier := "4a9505b9-0857-42bb-ab3c-098b4d28ddc2" + invalidAuthCode := authCode + "123" + invalidVerifier := codeVerifier + "123" + codeChallenge := sha256.Sum256([]byte(codeVerifier)) + challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:]) + flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil) + flowState.AuthCode = authCode + require.NoError(ts.T(), ts.API.db.Create(flowState)) + cases := []struct { + desc string + authCode string + codeVerifier string + grantType string + expectedHTTPCode int + }{ + { + desc: "Invalid Authcode", + authCode: invalidAuthCode, + codeVerifier: codeVerifier, + }, + { + desc: "Invalid code verifier", + authCode: authCode, + codeVerifier: invalidVerifier, + }, + { + desc: "Invalid auth code and verifier", + authCode: invalidAuthCode, + codeVerifier: invalidVerifier, + }, + } + for _, v := range cases { + ts.Run(v.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": v.codeVerifier, + "auth_code": v.authCode, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusNotFound, w.Code) + }) + } +} + +func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() { + _ = ts.createBannedUser() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() { + originalSecurity := ts.API.config.Security + + ts.API.config.Security.RefreshTokenRotationEnabled = true + ts.API.config.Security.RefreshTokenReuseInterval = 0 + + defer func() { + ts.API.config.Security = originalSecurity + }() + + refreshTokens := []string{ + ts.RefreshToken.Token, + } + + for i := 0; i < 3; i += 1 { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": refreshTokens[len(refreshTokens)-1], + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var response struct { + RefreshToken string `json:"refresh_token"` + } + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + + refreshTokens = append(refreshTokens, response.RefreshToken) + } + + // ensure that the 4 refresh tokens are setup correctly + for i, refreshToken := range refreshTokens { + _, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false) + require.NoError(ts.T(), err) + + if i == len(refreshTokens)-1 { + require.False(ts.T(), token.Revoked) + } else { + require.True(ts.T(), token.Revoked) + } + } + + // try to reuse the first (earliest) refresh token which should trigger the family revocation logic + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": refreshTokens[0], + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var response struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + require.Equal(ts.T(), apierrors.ErrorCodeRefreshTokenAlreadyUsed, response.ErrorCode) + require.Equal(ts.T(), "Invalid Refresh Token: Already Used", response.Message) + + // ensure that the refresh tokens are marked as revoked in the database + for _, refreshToken := range refreshTokens { + _, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false) + require.NoError(ts.T(), err) + + require.True(ts.T(), token.Revoked) + } + + // finally ensure that none of the refresh tokens can be reused any + // more, starting with the previously valid one + for i := len(refreshTokens) - 1; i >= 0; i -= 1 { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": refreshTokens[i], + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code, "For refresh token %d", i) + + var response struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + require.Equal(ts.T(), apierrors.ErrorCodeRefreshTokenAlreadyUsed, response.ErrorCode, "For refresh token %d", i) + require.Equal(ts.T(), "Invalid Refresh Token: Already Used", response.Message, "For refresh token %d", i) + } +} + +func (ts *TokenTestSuite) createBannedUser() *models.User { + u, err := models.NewUser("", "banned@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + t := time.Now() + u.EmailConfirmedAt = &t + t = t.Add(24 * time.Hour) + u.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test banned user") + + ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) + require.NoError(ts.T(), err, "Error creating refresh token") + + return u +} + +func (ts *TokenTestSuite) TestTokenRefreshWithExpiredSession() { + var err error + + now := time.Now().UTC().Add(-1 * time.Second) + + ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{ + SessionNotAfter: &now, + }) + require.NoError(ts.T(), err, "Error creating refresh token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) TestTokenRefreshWithUnexpiredSession() { + var err error + + now := time.Now().UTC().Add(1 * time.Second) + + ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{ + SessionNotAfter: &now, + }) + require.NoError(ts.T(), err, "Error creating refresh token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *TokenTestSuite) TestMagicLinkPKCESignIn() { + var buffer bytes.Buffer + // Send OTP + codeVerifier := "4a9505b9-0857-42bb-ab3c-098b4d28ddc2" + codeChallenge := sha256.Sum256([]byte(codeVerifier)) + challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:]) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(OtpParams{ + Email: "test@example.com", + CreateUser: true, + CodeChallengeMethod: "s256", + CodeChallenge: challenge, + })) + req = httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Verify OTP + requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", "magiclink", u.RecoveryToken) + req = httptest.NewRequest(http.MethodGet, requestUrl, &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) + + f, err := url.ParseQuery(rURL.RawQuery) + require.NoError(ts.T(), err) + authCode := f.Get("code") + assert.NotEmpty(ts.T(), authCode) + // Extract token and sign in + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": codeVerifier, + "auth_code": authCode, + })) + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + verifyResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&verifyResp)) + require.NotEmpty(ts.T(), verifyResp.Token) + +} + +func (ts *TokenTestSuite) TestPasswordVerificationHook() { + type verificationHookTestcase struct { + desc string + uri string + hookFunctionSQL string + expectedCode int + } + cases := []verificationHookTestcase{ + { + desc: "Default success", + uri: "pg-functions://postgres/auth/password_verification_hook", + hookFunctionSQL: ` + create or replace function password_verification_hook(input jsonb) + returns jsonb as $$ + begin + return jsonb_build_object('decision', 'continue'); + end; $$ language plpgsql;`, + expectedCode: http.StatusOK, + }, { + desc: "Reject- Enabled", + uri: "pg-functions://postgres/auth/password_verification_hook_reject", + hookFunctionSQL: ` + create or replace function password_verification_hook_reject(input jsonb) + returns jsonb as $$ + begin + return jsonb_build_object('decision', 'reject', 'message', 'You shall not pass!'); + end; $$ language plpgsql;`, + expectedCode: http.StatusBadRequest, + }, + } + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.PasswordVerificationAttempt.Enabled = true + ts.Config.Hook.PasswordVerificationAttempt.URI = c.uri + require.NoError(ts.T(), ts.Config.Hook.PasswordVerificationAttempt.PopulateExtensibilityPoint()) + + err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() + require.NoError(t, err) + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "password", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), c.expectedCode, w.Code) + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.PasswordVerificationAttempt.HookName) + require.NoError(ts.T(), ts.API.db.RawQuery(cleanupHookSQL).Exec()) + // Reset so it doesn't affect other tests + ts.Config.Hook.PasswordVerificationAttempt.Enabled = false + + }) + } + +} + +func (ts *TokenTestSuite) TestCustomAccessToken() { + type customAccessTokenTestcase struct { + desc string + uri string + hookFunctionSQL string + expectedClaims map[string]interface{} + shouldError bool + } + cases := []customAccessTokenTestcase{ + { + desc: "Add a new claim", + uri: "pg-functions://postgres/auth/custom_access_token_add_claim", + hookFunctionSQL: ` create or replace function custom_access_token_add_claim(input jsonb) returns jsonb as $$ declare result jsonb; begin if jsonb_typeof(jsonb_object_field(input, 'claims')) is null then result := jsonb_build_object('error', jsonb_build_object('http_code', 400, 'message', 'Input does not contain claims field')); return result; end if; + input := jsonb_set(input, '{claims,newclaim}', '"newvalue"', true); + result := jsonb_build_object('claims', input->'claims'); + return result; +end; $$ language plpgsql;`, + expectedClaims: map[string]interface{}{ + "newclaim": "newvalue", + }, + }, { + desc: "Delete the Role claim", + uri: "pg-functions://postgres/auth/custom_access_token_delete_claim", + hookFunctionSQL: ` +create or replace function custom_access_token_delete_claim(input jsonb) +returns jsonb as $$ +declare + result jsonb; +begin + input := jsonb_set(input, '{claims}', (input->'claims') - 'role'); + result := jsonb_build_object('claims', input->'claims'); + return result; +end; $$ language plpgsql;`, + expectedClaims: map[string]interface{}{}, + shouldError: true, + }, { + desc: "Delete a non-required claim (UserMetadata)", + uri: "pg-functions://postgres/auth/custom_access_token_delete_usermetadata", + hookFunctionSQL: ` +create or replace function custom_access_token_delete_usermetadata(input jsonb) +returns jsonb as $$ +declare + result jsonb; +begin + input := jsonb_set(input, '{claims}', (input->'claims') - 'user_metadata'); + result := jsonb_build_object('claims', input->'claims'); + return result; +end; $$ language plpgsql;`, + // Not used + expectedClaims: map[string]interface{}{ + "user_metadata": nil, + }, + shouldError: false, + }, + } + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.CustomAccessToken.Enabled = true + ts.Config.Hook.CustomAccessToken.URI = c.uri + require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint()) + + err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() + require.NoError(t, err) + + var buffer bytes.Buffer + require.NoError(t, json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + var tokenResponse struct { + AccessToken string `json:"access_token"` + } + require.NoError(t, json.NewDecoder(w.Result().Body).Decode(&tokenResponse)) + if c.shouldError { + require.Equal(t, http.StatusInternalServerError, w.Code) + } else { + parts := strings.Split(tokenResponse.AccessToken, ".") + require.Equal(t, 3, len(parts), "Token should have 3 parts") + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + + var responseClaims map[string]interface{} + require.NoError(t, json.Unmarshal(payload, &responseClaims)) + + for key, expectedValue := range c.expectedClaims { + if expectedValue == nil { + // Since c.shouldError is false here, we only need to check if the claim should be removed + _, exists := responseClaims[key] + assert.False(t, exists, "Claim should be removed") + } else { + assert.Equal(t, expectedValue, responseClaims[key]) + } + } + } + + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName) + require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec()) + ts.Config.Hook.CustomAccessToken.Enabled = false + }) + } +} + +func (ts *TokenTestSuite) TestAllowSelectAuthenticationMethods() { + + companyUser, err := models.NewUser("12345678", "test@company.com", "password", ts.Config.JWT.Aud, nil) + t := time.Now() + companyUser.EmailConfirmedAt = &t + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(companyUser), "Error saving new test user") + + type allowSelectAuthMethodsTestcase struct { + desc string + uri string + email string + expectedError string + expectedStatus int + } + + // Common hook function SQL definition + hookFunctionSQL := ` +create or replace function auth.custom_access_token(event jsonb) returns jsonb language plpgsql as $$ +declare + email_claim text; + authentication_method text; +begin + email_claim := event->'claims'->>'email'; + authentication_method := event->>'authentication_method'; + + if authentication_method = 'password' and email_claim not like '%@company.com' then + return jsonb_build_object( + 'error', jsonb_build_object( + 'http_code', 403, + 'message', 'only members on company.com can access with password authentication' + ) + ); + end if; + + return event; +end; +$$;` + + cases := []allowSelectAuthMethodsTestcase{ + { + desc: "Error for non-protected domain with password authentication", + uri: "pg-functions://postgres/auth/custom_access_token", + email: "test@example.com", + expectedError: "only members on company.com can access with password authentication", + expectedStatus: http.StatusForbidden, + }, + { + desc: "Allow access for protected domain with password authentication", + uri: "pg-functions://postgres/auth/custom_access_token", + email: companyUser.Email.String(), + expectedError: "", + expectedStatus: http.StatusOK, + }, + } + + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + // Enable and set up the custom access token hook + ts.Config.Hook.CustomAccessToken.Enabled = true + ts.Config.Hook.CustomAccessToken.URI = c.uri + require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint()) + + // Execute the common hook function SQL + err := ts.API.db.RawQuery(hookFunctionSQL).Exec() + require.NoError(t, err) + + var buffer bytes.Buffer + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": c.email, + "password": "password", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(t, c.expectedStatus, w.Code, "Unexpected HTTP status code") + if c.expectedError != "" { + require.Contains(t, w.Body.String(), c.expectedError, "Expected error message not found") + } else { + require.NotContains(t, w.Body.String(), "error", "Unexpected error occurred") + } + + // Delete the function and cleanup + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName) + require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec()) + ts.Config.Hook.CustomAccessToken.Enabled = false + }) + } +} diff --git a/internal/api/user.go b/internal/api/user.go new file mode 100644 index 000000000..815511970 --- /dev/null +++ b/internal/api/user.go @@ -0,0 +1,267 @@ +package api + +import ( + "context" + "net/http" + "time" + + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// UserUpdateParams parameters for updating a user +type UserUpdateParams struct { + Email string `json:"email"` + Password *string `json:"password"` + Nonce string `json:"nonce"` + Data map[string]interface{} `json:"data"` + AppData map[string]interface{} `json:"app_metadata,omitempty"` + Phone string `json:"phone"` + Channel string `json:"channel"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` +} + +func (a *API) validateUserUpdateParams(ctx context.Context, p *UserUpdateParams) error { + config := a.config + + var err error + if p.Email != "" { + p.Email, err = a.validateEmail(p.Email) + if err != nil { + return err + } + } + + if p.Phone != "" { + if p.Phone, err = validatePhone(p.Phone); err != nil { + return err + } + if p.Channel == "" { + p.Channel = sms_provider.SMSProvider + } + if !sms_provider.IsValidMessageChannel(p.Channel, config) { + return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) + } + } + + if p.Password != nil { + if err := a.checkPasswordStrength(ctx, *p.Password); err != nil { + return err + } + } + + return nil +} + +// UserGet returns a user +func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + claims := getClaims(ctx) + if claims == nil { + return internalServerError("Could not read claims") + } + + aud := a.requestAud(ctx, r) + audienceFromClaims, _ := claims.GetAudience() + if len(audienceFromClaims) == 0 || aud != audienceFromClaims[0] { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Token audience doesn't match request audience") + } + + user := getUser(ctx) + return sendJSON(w, http.StatusOK, user) +} + +// UserUpdate updates fields on a user +func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + aud := a.requestAud(ctx, r) + + params := &UserUpdateParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + user := getUser(ctx) + session := getSession(ctx) + + if err := a.validateUserUpdateParams(ctx, params); err != nil { + return err + } + + if params.AppData != nil && !isAdmin(user, config) { + if !isAdmin(user, config) { + return forbiddenError(apierrors.ErrorCodeNotAdmin, "Updating app_metadata requires admin privileges") + } + } + + if user.HasMFAEnabled() && !session.IsAAL2() { + if (params.Password != nil && *params.Password != "") || (params.Email != "" && user.GetEmail() != params.Email) || (params.Phone != "" && user.GetPhone() != params.Phone) { + return httpError(http.StatusUnauthorized, apierrors.ErrorCodeInsufficientAAL, "AAL2 session is required to update email or password when MFA is enabled.") + } + } + + if user.IsAnonymous { + if params.Password != nil && *params.Password != "" { + if params.Email == "" && params.Phone == "" { + return unprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Updating password of an anonymous user without an email or phone is not allowed") + } + } + } + + if user.IsSSOUser { + updatingForbiddenFields := false + + updatingForbiddenFields = updatingForbiddenFields || (params.Password != nil && *params.Password != "") + updatingForbiddenFields = updatingForbiddenFields || (params.Email != "" && params.Email != user.GetEmail()) + updatingForbiddenFields = updatingForbiddenFields || (params.Phone != "" && params.Phone != user.GetPhone()) + updatingForbiddenFields = updatingForbiddenFields || (params.Nonce != "") + + if updatingForbiddenFields { + return unprocessableEntityError(apierrors.ErrorCodeUserSSOManaged, "Updating email, phone, password of a SSO account only possible via SSO") + } + } + + if params.Email != "" && user.GetEmail() != params.Email { + if duplicateUser, err := models.IsDuplicatedEmail(db, params.Email, aud, user); err != nil { + return internalServerError("Database error checking email").WithInternalError(err) + } else if duplicateUser != nil { + return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + } + } + + if params.Phone != "" && user.GetPhone() != params.Phone { + if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { + return internalServerError("Database error checking phone").WithInternalError(err) + } else if exists { + return unprocessableEntityError(apierrors.ErrorCodePhoneExists, DuplicatePhoneMsg) + } + } + + if params.Password != nil { + if config.Security.UpdatePasswordRequireReauthentication { + now := time.Now() + // we require reauthentication if the user hasn't signed in recently in the current session + if session == nil || now.After(session.CreatedAt.Add(24*time.Hour)) { + if len(params.Nonce) == 0 { + return badRequestError(apierrors.ErrorCodeReauthenticationNeeded, "Password update requires reauthentication") + } + if err := a.verifyReauthentication(params.Nonce, db, config, user); err != nil { + return err + } + } + } + + password := *params.Password + if password != "" { + isSamePassword := false + + if user.HasPassword() { + auth, _, err := user.Authenticate(ctx, db, password, config.Security.DBEncryption.DecryptionKeys, false, "") + if err != nil { + return err + } + + isSamePassword = auth + } + + if isSamePassword { + return unprocessableEntityError(apierrors.ErrorCodeSamePassword, "New password should be different from the old password.") + } + } + + if err := user.SetPassword(ctx, password, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + } + + err := db.Transaction(func(tx *storage.Connection) error { + var terr error + if params.Password != nil { + var sessionID *uuid.UUID + if session != nil { + sessionID = &session.ID + } + + if terr = user.UpdatePassword(tx, sessionID); terr != nil { + return internalServerError("Error during password storage").WithInternalError(terr) + } + + if terr := models.NewAuditLogEntry(r, tx, user, models.UserUpdatePasswordAction, "", nil); terr != nil { + return terr + } + } + + if params.Data != nil { + if terr = user.UpdateUserMetaData(tx, params.Data); terr != nil { + return internalServerError("Error updating user").WithInternalError(terr) + } + } + + if params.AppData != nil { + if terr = user.UpdateAppMetaData(tx, params.AppData); terr != nil { + return internalServerError("Error updating user").WithInternalError(terr) + } + } + + if params.Email != "" && params.Email != user.GetEmail() { + if user.IsAnonymous && config.Mailer.Autoconfirm { + // anonymous users can add an email with automatic confirmation, which is similar to signing up + // permanent users always need to verify their email address when changing it + user.EmailChange = params.Email + if _, terr := a.emailChangeVerify(r, tx, &VerifyParams{ + Type: mailer.EmailChangeVerification, + Email: params.Email, + }, user); terr != nil { + return terr + } + + } else { + flowType := getFlowFromChallenge(params.CodeChallenge) + if isPKCEFlow(flowType) { + _, terr := generateFlowState(tx, models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + if terr != nil { + return terr + } + + } + if terr = a.sendEmailChange(r, tx, user, params.Email, flowType); terr != nil { + return terr + } + } + } + + if params.Phone != "" && params.Phone != user.GetPhone() { + if config.Sms.Autoconfirm { + user.PhoneChange = params.Phone + if _, terr := a.smsVerify(r, tx, user, &VerifyParams{ + Type: phoneChangeVerification, + Phone: params.Phone, + }); terr != nil { + return terr + } + } else { + if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneChangeVerification, params.Channel); terr != nil { + return terr + } + } + } + + if terr = models.NewAuditLogEntry(r, tx, user, models.UserModifiedAction, "", nil); terr != nil { + return internalServerError("Error recording audit log entry").WithInternalError(terr) + } + + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, user) +} diff --git a/internal/api/user_test.go b/internal/api/user_test.go new file mode 100644 index 000000000..ed6c585af --- /dev/null +++ b/internal/api/user_test.go @@ -0,0 +1,558 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" +) + +type UserTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestUser(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &UserTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *UserTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("123456789", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") +} + +func (ts *UserTestSuite) generateToken(user *models.User, sessionId *uuid.UUID) string { + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, sessionId, models.PasswordGrant) + require.NoError(ts.T(), err, "Error generating access token") + return token +} + +func (ts *UserTestSuite) generateAccessTokenAndSession(user *models.User) string { + session, err := models.NewSession(user.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, &session.ID, models.PasswordGrant) + require.NoError(ts.T(), err, "Error generating access token") + return token +} + +func (ts *UserTestSuite) TestUserGet() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err, "Error finding user") + token := ts.generateAccessTokenAndSession(u) + + require.NoError(ts.T(), err, "Error generating access token") + + req := httptest.NewRequest(http.MethodGet, "http://localhost/user", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *UserTestSuite) TestUserUpdateEmail() { + cases := []struct { + desc string + userData map[string]interface{} + isSecureEmailChangeEnabled bool + isMailerAutoconfirmEnabled bool + expectedCode int + }{ + { + desc: "User doesn't have an existing email", + userData: map[string]interface{}{ + "email": "", + "phone": "", + }, + isSecureEmailChangeEnabled: false, + isMailerAutoconfirmEnabled: false, + expectedCode: http.StatusOK, + }, + { + desc: "User doesn't have an existing email and double email confirmation required", + userData: map[string]interface{}{ + "email": "", + "phone": "234567890", + }, + isSecureEmailChangeEnabled: true, + isMailerAutoconfirmEnabled: false, + expectedCode: http.StatusOK, + }, + { + desc: "User has an existing email", + userData: map[string]interface{}{ + "email": "foo@example.com", + "phone": "", + }, + isSecureEmailChangeEnabled: false, + isMailerAutoconfirmEnabled: false, + expectedCode: http.StatusOK, + }, + { + desc: "User has an existing email and double email confirmation required", + userData: map[string]interface{}{ + "email": "bar@example.com", + "phone": "", + }, + isSecureEmailChangeEnabled: true, + isMailerAutoconfirmEnabled: false, + expectedCode: http.StatusOK, + }, + { + desc: "Update email with mailer autoconfirm enabled", + userData: map[string]interface{}{ + "email": "bar@example.com", + "phone": "", + }, + isSecureEmailChangeEnabled: true, + isMailerAutoconfirmEnabled: true, + expectedCode: http.StatusOK, + }, + { + desc: "Update email with mailer autoconfirm enabled and anonymous user", + userData: map[string]interface{}{ + "email": "bar@example.com", + "phone": "", + "is_anonymous": true, + }, + isSecureEmailChangeEnabled: true, + isMailerAutoconfirmEnabled: true, + expectedCode: http.StatusOK, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := models.NewUser("", "", "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), u.SetEmail(ts.API.db, c.userData["email"].(string)), "Error setting user email") + require.NoError(ts.T(), u.SetPhone(ts.API.db, c.userData["phone"].(string)), "Error setting user phone") + if isAnonymous, ok := c.userData["is_anonymous"]; ok { + u.IsAnonymous = isAnonymous.(bool) + } + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user") + + token := ts.generateAccessTokenAndSession(u) + + require.NoError(ts.T(), err, "Error generating access token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "new@example.com", + })) + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.Config.Mailer.SecureEmailChangeEnabled = c.isSecureEmailChangeEnabled + ts.Config.Mailer.Autoconfirm = c.isMailerAutoconfirmEnabled + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expectedCode, w.Code) + + var data models.User + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + if c.isMailerAutoconfirmEnabled && u.IsAnonymous { + require.Empty(ts.T(), data.EmailChange) + require.Equal(ts.T(), "new@example.com", data.GetEmail()) + require.Len(ts.T(), data.Identities, 1) + } else { + require.Equal(ts.T(), "new@example.com", data.EmailChange) + require.Len(ts.T(), data.Identities, 0) + } + + // remove user after each case + require.NoError(ts.T(), ts.API.db.Destroy(u)) + }) + } + +} +func (ts *UserTestSuite) TestUserUpdatePhoneAutoconfirmEnabled() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + existingUser, err := models.NewUser("22222222", "", "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(existingUser)) + + cases := []struct { + desc string + userData map[string]string + expectedCode int + }{ + { + desc: "New phone number is the same as current phone number", + userData: map[string]string{ + "phone": "123456789", + }, + expectedCode: http.StatusOK, + }, + { + desc: "New phone number exists already", + userData: map[string]string{ + "phone": "22222222", + }, + expectedCode: http.StatusUnprocessableEntity, + }, + { + desc: "New phone number is different from current phone number", + userData: map[string]string{ + "phone": "234567890", + }, + expectedCode: http.StatusOK, + }, + } + + ts.Config.Sms.Autoconfirm = true + + for _, c := range cases { + ts.Run(c.desc, func() { + token := ts.generateAccessTokenAndSession(u) + require.NoError(ts.T(), err, "Error generating access token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "phone": c.userData["phone"], + })) + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expectedCode, w.Code) + + if c.expectedCode == http.StatusOK { + // check that the user response returned contains the updated phone field + data := &models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), data.GetPhone(), c.userData["phone"]) + } + }) + } + +} + +func (ts *UserTestSuite) TestUserUpdatePassword() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + r, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) + require.NoError(ts.T(), err) + + r2, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) + require.NoError(ts.T(), err) + + // create a session and modify it's created_at time to simulate a session that is not recently logged in + notRecentlyLoggedIn, err := models.FindSessionByID(ts.API.db, *r2.SessionId, true) + require.NoError(ts.T(), err) + + // cannot use Update here because Update doesn't removes the created_at field + require.NoError(ts.T(), ts.API.db.RawQuery( + "update "+notRecentlyLoggedIn.TableName()+" set created_at = ? where id = ?", + time.Now().Add(-24*time.Hour), + notRecentlyLoggedIn.ID).Exec(), + ) + + type expected struct { + code int + isAuthenticated bool + } + + var cases = []struct { + desc string + newPassword string + nonce string + requireReauthentication bool + sessionId *uuid.UUID + expected expected + }{ + { + desc: "Need reauthentication because outside of recently logged in window", + newPassword: "newpassword123", + nonce: "", + requireReauthentication: true, + sessionId: ¬RecentlyLoggedIn.ID, + expected: expected{code: http.StatusBadRequest, isAuthenticated: false}, + }, + { + desc: "No nonce provided", + newPassword: "newpassword123", + nonce: "", + sessionId: ¬RecentlyLoggedIn.ID, + requireReauthentication: true, + expected: expected{code: http.StatusBadRequest, isAuthenticated: false}, + }, + { + desc: "Invalid nonce", + newPassword: "newpassword1234", + nonce: "123456", + sessionId: ¬RecentlyLoggedIn.ID, + requireReauthentication: true, + expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false}, + }, + { + desc: "No need reauthentication because recently logged in", + newPassword: "newpassword123", + nonce: "", + requireReauthentication: true, + sessionId: r.SessionId, + expected: expected{code: http.StatusOK, isAuthenticated: true}, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.Security.UpdatePasswordRequireReauthentication = c.requireReauthentication + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"password": c.newPassword, "nonce": c.nonce})) + + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + token := ts.generateToken(u, c.sessionId) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected.code, w.Code) + + // Request body + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), c.expected.isAuthenticated, isAuthenticated) + }) + } +} + +func (ts *UserTestSuite) TestUserUpdatePasswordNoReauthenticationRequired() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + type expected struct { + code int + isAuthenticated bool + } + + var cases = []struct { + desc string + newPassword string + nonce string + requireReauthentication bool + expected expected + }{ + { + desc: "Invalid password length", + newPassword: "", + nonce: "", + requireReauthentication: false, + expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false}, + }, + + { + desc: "Valid password length", + newPassword: "newpassword", + nonce: "", + requireReauthentication: false, + expected: expected{code: http.StatusOK, isAuthenticated: true}, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.Security.UpdatePasswordRequireReauthentication = c.requireReauthentication + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"password": c.newPassword, "nonce": c.nonce})) + + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + token := ts.generateAccessTokenAndSession(u) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected.code, w.Code) + + // Request body + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), c.expected.isAuthenticated, isAuthenticated) + }) + } +} + +func (ts *UserTestSuite) TestUserUpdatePasswordReauthentication() { + ts.Config.Security.UpdatePasswordRequireReauthentication = true + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Confirm the test user + now := time.Now() + u.EmailConfirmedAt = &now + require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user") + + token := ts.generateAccessTokenAndSession(u) + + // request for reauthentication nonce + req := httptest.NewRequest(http.MethodGet, "http://localhost/reauthenticate", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), w.Code, http.StatusOK) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), u.ReauthenticationToken) + require.NotEmpty(ts.T(), u.ReauthenticationSentAt) + + // update reauthentication token to a known token + u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), "123456") + require.NoError(ts.T(), ts.API.db.Update(u)) + + // update password with reauthentication token + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": "newpass", + "nonce": "123456", + })) + + req = httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), w.Code, http.StatusOK) + + // Request body + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, "newpass", ts.Config.Security.DBEncryption.DecryptionKeys, ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + + require.True(ts.T(), isAuthenticated) + require.Empty(ts.T(), u.ReauthenticationToken) + require.Nil(ts.T(), u.ReauthenticationSentAt) +} + +func (ts *UserTestSuite) TestUserUpdatePasswordLogoutOtherSessions() { + ts.Config.Security.UpdatePasswordRequireReauthentication = false + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Confirm the test user + now := time.Now() + u.EmailConfirmedAt = &now + require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user") + + // Login the test user to get first session + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": u.GetEmail(), + "password": "password", + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + session1 := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&session1)) + + // Login test user to get second session + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": u.GetEmail(), + "password": "password", + })) + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + session2 := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&session2)) + + // Update user's password using first session + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": "newpass", + })) + + req = httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session1.Token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Attempt to refresh session1 should pass + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": session1.RefreshToken, + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Attempt to refresh session2 should fail + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": session2.RefreshToken, + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.NotEqual(ts.T(), http.StatusOK, w.Code) +} diff --git a/internal/api/verify.go b/internal/api/verify.go new file mode 100644 index 000000000..0d3c09b19 --- /dev/null +++ b/internal/api/verify.go @@ -0,0 +1,750 @@ +package api + +import ( + "context" + "errors" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/fatih/structs" + "github.com/sethvargo/go-password/password" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/crypto" + mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +const ( + smsVerification = "sms" + phoneChangeVerification = "phone_change" + // includes signupVerification and magicLinkVerification +) + +const ( + zeroConfirmation int = iota + singleConfirmation +) + +// Only applicable when SECURE_EMAIL_CHANGE_ENABLED +const singleConfirmationAccepted = "Confirmation link accepted. Please proceed to confirm link sent to the other email" + +// VerifyParams are the parameters the Verify endpoint accepts +type VerifyParams struct { + Type string `json:"type"` + Token string `json:"token"` + TokenHash string `json:"token_hash"` + Email string `json:"email"` + Phone string `json:"phone"` + RedirectTo string `json:"redirect_to"` +} + +func (p *VerifyParams) Validate(r *http.Request, a *API) error { + var err error + if p.Type == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a verification type") + } + switch r.Method { + case http.MethodGet: + if p.Token == "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a token or a token hash") + } + // TODO: deprecate the token query param from GET /verify and use token_hash instead (breaking change) + p.TokenHash = p.Token + case http.MethodPost: + if (p.Token == "" && p.TokenHash == "") || (p.Token != "" && p.TokenHash != "") { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires either a token or a token hash") + } + if p.Token != "" { + if isPhoneOtpVerification(p) { + p.Phone, err = validatePhone(p.Phone) + if err != nil { + return err + } + p.TokenHash = crypto.GenerateTokenHash(p.Phone, p.Token) + } else if isEmailOtpVerification(p) { + p.Email, err = a.validateEmail(p.Email) + if err != nil { + return unprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Invalid email format").WithInternalError(err) + } + p.TokenHash = crypto.GenerateTokenHash(p.Email, p.Token) + } else { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify") + } + } else if p.TokenHash != "" { + if p.Email != "" || p.Phone != "" || p.RedirectTo != "" { + return badRequestError(apierrors.ErrorCodeValidationFailed, "Only the token_hash and type should be provided") + } + } + default: + return nil + } + return nil +} + +// Verify exchanges a confirmation or recovery token to a refresh token +func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { + params := &VerifyParams{} + switch r.Method { + case http.MethodGet: + params.Token = r.FormValue("token") + params.Type = r.FormValue("type") + params.RedirectTo = utilities.GetReferrer(r, a.config) + if err := params.Validate(r, a); err != nil { + return err + } + return a.verifyGet(w, r, params) + case http.MethodPost: + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if err := params.Validate(r, a); err != nil { + return err + } + return a.verifyPost(w, r, params) + default: + // this should have been handled by Chi + panic("Only GET and POST methods allowed") + } +} + +func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyParams) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + var ( + user *models.User + grantParams models.GrantParams + err error + token *AccessTokenResponse + authCode string + rurl string + ) + + grantParams.FillGrantParams(r) + + flowType := models.ImplicitFlow + var authenticationMethod models.AuthenticationMethod + if strings.HasPrefix(params.Token, PKCEPrefix) { + flowType = models.PKCEFlow + authenticationMethod, err = models.ParseAuthenticationMethod(params.Type) + if err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + user, terr = a.verifyTokenHash(tx, params) + if terr != nil { + return terr + } + switch params.Type { + case mail.SignupVerification, mail.InviteVerification: + user, terr = a.signupVerify(r, ctx, tx, user) + case mail.RecoveryVerification, mail.MagicLinkVerification: + user, terr = a.recoverVerify(r, tx, user) + case mail.EmailChangeVerification: + user, terr = a.emailChangeVerify(r, tx, params, user) + if user == nil && terr == nil { + // only one OTP is confirmed at this point, so we return early and ask the user to confirm the second OTP + rurl, terr = a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType) + if terr != nil { + return terr + } + return nil + } + default: + return badRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported verification type") + } + + if terr != nil { + return terr + } + + if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { + return terr + } + + // Reload user model from db. + // This is important for refreshing the data in any generated columns like IsAnonymous. + if terr := tx.Reload(user); err != nil { + return terr + } + + if isImplicitFlow(flowType) { + token, terr = a.issueRefreshToken(r, tx, user, models.OTP, grantParams) + if terr != nil { + return terr + } + + } else if isPKCEFlow(flowType) { + if authCode, terr = issueAuthCode(tx, user, authenticationMethod); terr != nil { + return badRequestError(apierrors.ErrorCodeFlowStateNotFound, "No associated flow state found. %s", terr) + } + } + return nil + }) + + if err != nil { + var herr *HTTPError + if errors.As(err, &herr) { + rurl, err = a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType) + if err != nil { + return err + } + } + } + if rurl != "" { + http.Redirect(w, r, rurl, http.StatusSeeOther) + return nil + } + rurl = params.RedirectTo + if isImplicitFlow(flowType) && token != nil { + q := url.Values{} + q.Set("type", params.Type) + rurl = token.AsRedirectURL(rurl, q) + } else if isPKCEFlow(flowType) { + rurl, err = a.prepPKCERedirectURL(rurl, authCode) + if err != nil { + return err + } + } + http.Redirect(w, r, rurl, http.StatusSeeOther) + return nil +} + +func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyParams) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + var ( + user *models.User + grantParams models.GrantParams + token *AccessTokenResponse + ) + var isSingleConfirmationResponse = false + + grantParams.FillGrantParams(r) + + err := db.Transaction(func(tx *storage.Connection) error { + var terr error + aud := a.requestAud(ctx, r) + + if isUsingTokenHash(params) { + user, terr = a.verifyTokenHash(tx, params) + } else { + user, terr = a.verifyUserAndToken(tx, params, aud) + } + if terr != nil { + return terr + } + + switch params.Type { + case mail.SignupVerification, mail.InviteVerification: + user, terr = a.signupVerify(r, ctx, tx, user) + case mail.RecoveryVerification, mail.MagicLinkVerification: + user, terr = a.recoverVerify(r, tx, user) + case mail.EmailChangeVerification: + user, terr = a.emailChangeVerify(r, tx, params, user) + if user == nil && terr == nil { + isSingleConfirmationResponse = true + return nil + } + case smsVerification, phoneChangeVerification: + user, terr = a.smsVerify(r, tx, user, params) + default: + return badRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported verification type") + } + + if terr != nil { + return terr + } + + if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { + return terr + } + + // Reload user model from db. + // This is important for refreshing the data in any generated columns like IsAnonymous. + if terr := tx.Reload(user); terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, user, models.OTP, grantParams) + if terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + if isSingleConfirmationResponse { + return sendJSON(w, http.StatusOK, map[string]string{ + "msg": singleConfirmationAccepted, + "code": strconv.Itoa(http.StatusOK), + }) + } + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.Connection, user *models.User) (*models.User, error) { + config := a.config + + shouldUpdatePassword := false + if !user.HasPassword() && user.InvitedAt != nil { + // sign them up with temporary password, and require application + // to present the user with a password set form + password, err := password.Generate(64, 10, 0, false, true) + if err != nil { + // password generation must succeed + panic(err) + } + + if err := user.SetPassword(ctx, password, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return nil, err + } + shouldUpdatePassword = true + } + + err := conn.Transaction(func(tx *storage.Connection) error { + var terr error + if shouldUpdatePassword { + if terr = user.UpdatePassword(tx, nil); terr != nil { + return internalServerError("Error storing password").WithInternalError(terr) + } + } + + if terr = models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", nil); terr != nil { + return terr + } + + if terr = user.Confirm(tx); terr != nil { + return internalServerError("Error confirming user").WithInternalError(terr) + } + + for _, identity := range user.Identities { + if identity.Email == "" || user.Email == "" || identity.Email != user.Email { + continue + } + + if terr = identity.UpdateIdentityData(tx, map[string]interface{}{ + "email_verified": true, + }); terr != nil { + return internalServerError("Error setting email_verified to true on identity").WithInternalError(terr) + } + } + + return nil + }) + if err != nil { + return nil, err + } + return user, nil +} + +func (a *API) recoverVerify(r *http.Request, conn *storage.Connection, user *models.User) (*models.User, error) { + err := conn.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = user.Recover(tx); terr != nil { + return terr + } + if !user.IsConfirmed() { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", nil); terr != nil { + return terr + } + + if terr = user.Confirm(tx); terr != nil { + return terr + } + } else { + if terr = models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", nil); terr != nil { + return terr + } + } + return nil + }) + + if err != nil { + return nil, internalServerError("Database error updating user").WithInternalError(err) + } + return user, nil +} + +func (a *API) smsVerify(r *http.Request, conn *storage.Connection, user *models.User, params *VerifyParams) (*models.User, error) { + + err := conn.Transaction(func(tx *storage.Connection) error { + + if params.Type == smsVerification { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", nil); terr != nil { + return terr + } + if terr := user.ConfirmPhone(tx); terr != nil { + return internalServerError("Error confirming user").WithInternalError(terr) + } + } else if params.Type == phoneChangeVerification { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserModifiedAction, "", nil); terr != nil { + return terr + } + if identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), "phone"); terr != nil { + if !models.IsNotFoundError(terr) { + return terr + } + // confirming the phone change should create a new phone identity if the user doesn't have one + if _, terr = a.createNewIdentity(tx, user, "phone", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Phone: params.Phone, + PhoneVerified: true, + })); terr != nil { + return terr + } + } else { + if terr := identity.UpdateIdentityData(tx, map[string]interface{}{ + "phone": params.Phone, + "phone_verified": true, + }); terr != nil { + return terr + } + } + if terr := user.ConfirmPhoneChange(tx); terr != nil { + return internalServerError("Error confirming user").WithInternalError(terr) + } + } + + if user.IsAnonymous { + user.IsAnonymous = false + if terr := tx.UpdateOnly(user, "is_anonymous"); terr != nil { + return terr + } + } + + if terr := tx.Load(user, "Identities"); terr != nil { + return internalServerError("Error refetching identities").WithInternalError(terr) + } + return nil + }) + if err != nil { + return nil, err + } + return user, nil +} + +func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request, rurl string, flowType models.FlowType) (string, error) { + u, perr := url.Parse(rurl) + if perr != nil { + return "", err + } + q := u.Query() + + // Maintain separate query params for hash and query + hq := url.Values{} + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(r.Context()) + err.ErrorID = errorID + log.WithError(err.Cause()).Info(err.Error()) + if str, ok := oauthErrorMap[err.HTTPStatus]; ok { + hq.Set("error", str) + q.Set("error", str) + } + hq.Set("error_code", err.ErrorCode) + hq.Set("error_description", err.Message) + + q.Set("error_code", err.ErrorCode) + q.Set("error_description", err.Message) + if flowType == models.PKCEFlow { + // Additionally, may override existing error query param if set to PKCE. + u.RawQuery = q.Encode() + } + // Left as hash fragment to comply with spec. + u.Fragment = hq.Encode() + return u.String(), nil +} + +func (a *API) prepRedirectURL(message string, rurl string, flowType models.FlowType) (string, error) { + u, perr := url.Parse(rurl) + if perr != nil { + return "", perr + } + hq := url.Values{} + q := u.Query() + hq.Set("message", message) + if flowType == models.PKCEFlow { + q.Set("message", message) + } + u.RawQuery = q.Encode() + u.Fragment = hq.Encode() + return u.String(), nil +} + +func (a *API) prepPKCERedirectURL(rurl, code string) (string, error) { + u, err := url.Parse(rurl) + if err != nil { + return "", err + } + q := u.Query() + q.Set("code", code) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (a *API) emailChangeVerify(r *http.Request, conn *storage.Connection, params *VerifyParams, user *models.User) (*models.User, error) { + config := a.config + if !config.Mailer.Autoconfirm && + config.Mailer.SecureEmailChangeEnabled && + user.EmailChangeConfirmStatus == zeroConfirmation && + user.GetEmail() != "" { + err := conn.Transaction(func(tx *storage.Connection) error { + currentOTT, terr := models.FindOneTimeToken(tx, params.TokenHash, models.EmailChangeTokenCurrent) + if terr != nil && !models.IsNotFoundError(terr) { + return terr + } + + newOTT, terr := models.FindOneTimeToken(tx, params.TokenHash, models.EmailChangeTokenNew) + if terr != nil && !models.IsNotFoundError(terr) { + return terr + } + + user.EmailChangeConfirmStatus = singleConfirmation + + if params.Token == user.EmailChangeTokenCurrent || params.TokenHash == user.EmailChangeTokenCurrent || (currentOTT != nil && params.TokenHash == currentOTT.TokenHash) { + user.EmailChangeTokenCurrent = "" + if terr := models.ClearOneTimeTokenForUser(tx, user.ID, models.EmailChangeTokenCurrent); terr != nil { + return terr + } + } else if params.Token == user.EmailChangeTokenNew || params.TokenHash == user.EmailChangeTokenNew || (newOTT != nil && params.TokenHash == newOTT.TokenHash) { + user.EmailChangeTokenNew = "" + if terr := models.ClearOneTimeTokenForUser(tx, user.ID, models.EmailChangeTokenNew); terr != nil { + return terr + } + } + if terr := tx.UpdateOnly(user, "email_change_confirm_status", "email_change_token_current", "email_change_token_new"); terr != nil { + return terr + } + return nil + }) + if err != nil { + return nil, err + } + return nil, nil + } + + // one email is confirmed at this point if GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED is enabled + err := conn.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserModifiedAction, "", nil); terr != nil { + return terr + } + + if identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), "email"); terr != nil { + if !models.IsNotFoundError(terr) { + return terr + } + // confirming the email change should create a new email identity if the user doesn't have one + if _, terr = a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.EmailChange, + EmailVerified: true, + })); terr != nil { + return terr + } + } else { + if terr := identity.UpdateIdentityData(tx, map[string]interface{}{ + "email": user.EmailChange, + "email_verified": true, + }); terr != nil { + return terr + } + } + if user.IsAnonymous { + user.IsAnonymous = false + if terr := tx.UpdateOnly(user, "is_anonymous"); terr != nil { + return terr + } + } + if terr := tx.Load(user, "Identities"); terr != nil { + return internalServerError("Error refetching identities").WithInternalError(terr) + } + if terr := user.ConfirmEmailChange(tx, zeroConfirmation); terr != nil { + return internalServerError("Error confirm email").WithInternalError(terr) + } + + return nil + }) + if err != nil { + return nil, err + } + + return user, nil +} + +func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (*models.User, error) { + config := a.config + + var user *models.User + var err error + switch params.Type { + case mail.EmailOTPVerification: + // need to find user by confirmation token or recovery token with the token hash + user, err = models.FindUserByConfirmationOrRecoveryToken(conn, params.TokenHash) + case mail.SignupVerification, mail.InviteVerification: + user, err = models.FindUserByConfirmationToken(conn, params.TokenHash) + case mail.RecoveryVerification, mail.MagicLinkVerification: + user, err = models.FindUserByRecoveryToken(conn, params.TokenHash) + case mail.EmailChangeVerification: + user, err = models.FindUserByEmailChangeToken(conn, params.TokenHash) + default: + return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid email verification type") + } + + if err != nil { + if models.IsNotFoundError(err) { + return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalError(err) + } + return nil, internalServerError("Database error finding user from email link").WithInternalError(err) + } + + if user.IsBanned() { + return nil, forbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") + } + + var isExpired bool + switch params.Type { + case mail.EmailOTPVerification: + sentAt := user.ConfirmationSentAt + params.Type = "signup" + if user.RecoveryToken == params.TokenHash { + sentAt = user.RecoverySentAt + params.Type = "magiclink" + } + isExpired = isOtpExpired(sentAt, config.Mailer.OtpExp) + case mail.SignupVerification, mail.InviteVerification: + isExpired = isOtpExpired(user.ConfirmationSentAt, config.Mailer.OtpExp) + case mail.RecoveryVerification, mail.MagicLinkVerification: + isExpired = isOtpExpired(user.RecoverySentAt, config.Mailer.OtpExp) + case mail.EmailChangeVerification: + isExpired = isOtpExpired(user.EmailChangeSentAt, config.Mailer.OtpExp) + } + + if isExpired { + return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalMessage("email link has expired") + } + + return user, nil +} + +// verifyUserAndToken verifies the token associated to the user based on the verify type +func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, aud string) (*models.User, error) { + config := a.config + + var user *models.User + var err error + tokenHash := params.TokenHash + + switch params.Type { + case phoneChangeVerification: + user, err = models.FindUserByPhoneChangeAndAudience(conn, params.Phone, aud) + case smsVerification: + user, err = models.FindUserByPhoneAndAudience(conn, params.Phone, aud) + case mail.EmailChangeVerification: + // Since the email change could be trigger via the implicit or PKCE flow, + // the query used has to also check if the token saved in the db contains the pkce_ prefix + user, err = models.FindUserForEmailChange(conn, params.Email, tokenHash, aud, config.Mailer.SecureEmailChangeEnabled) + default: + user, err = models.FindUserByEmailAndAudience(conn, params.Email, aud) + } + + if err != nil { + if models.IsNotFoundError(err) { + return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + } + return nil, internalServerError("Database error finding user").WithInternalError(err) + } + + if user.IsBanned() { + return nil, forbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") + } + + var isValid bool + + smsProvider, _ := sms_provider.GetSmsProvider(*config) + switch params.Type { + case mail.EmailOTPVerification: + // if the type is emailOTPVerification, we'll check both the confirmation_token and recovery_token columns + if isOtpValid(tokenHash, user.ConfirmationToken, user.ConfirmationSentAt, config.Mailer.OtpExp) { + isValid = true + params.Type = mail.SignupVerification + } else if isOtpValid(tokenHash, user.RecoveryToken, user.RecoverySentAt, config.Mailer.OtpExp) { + isValid = true + params.Type = mail.MagicLinkVerification + } else { + isValid = false + } + case mail.SignupVerification, mail.InviteVerification: + isValid = isOtpValid(tokenHash, user.ConfirmationToken, user.ConfirmationSentAt, config.Mailer.OtpExp) + case mail.RecoveryVerification, mail.MagicLinkVerification: + isValid = isOtpValid(tokenHash, user.RecoveryToken, user.RecoverySentAt, config.Mailer.OtpExp) + case mail.EmailChangeVerification: + isValid = isOtpValid(tokenHash, user.EmailChangeTokenCurrent, user.EmailChangeSentAt, config.Mailer.OtpExp) || + isOtpValid(tokenHash, user.EmailChangeTokenNew, user.EmailChangeSentAt, config.Mailer.OtpExp) + case phoneChangeVerification, smsVerification: + if testOTP, ok := config.Sms.GetTestOTP(params.Phone, time.Now()); ok { + if params.Token == testOTP { + return user, nil + } + } + + phone := params.Phone + sentAt := user.ConfirmationSentAt + expectedToken := user.ConfirmationToken + if params.Type == phoneChangeVerification { + phone = user.PhoneChange + sentAt = user.PhoneChangeSentAt + expectedToken = user.PhoneChangeToken + } + + if !config.Hook.SendSMS.Enabled && config.Sms.IsTwilioVerifyProvider() { + if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(phone, params.Token); err != nil { + return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + } + return user, nil + } + isValid = isOtpValid(tokenHash, expectedToken, sentAt, config.Sms.OtpExp) + } + + if !isValid { + return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") + } + return user, nil +} + +// isOtpValid checks the actual otp sent against the expected otp and ensures that it's within the valid window +func isOtpValid(actual, expected string, sentAt *time.Time, otpExp uint) bool { + if expected == "" || sentAt == nil { + return false + } + return !isOtpExpired(sentAt, otpExp) && ((actual == expected) || ("pkce_"+actual == expected)) +} + +func isOtpExpired(sentAt *time.Time, otpExp uint) bool { + return time.Now().After(sentAt.Add(time.Second * time.Duration(otpExp))) // #nosec G115 +} + +// isPhoneOtpVerification checks if the verification came from a phone otp +func isPhoneOtpVerification(params *VerifyParams) bool { + return params.Phone != "" && params.Email == "" +} + +// isEmailOtpVerification checks if the verification came from an email otp +func isEmailOtpVerification(params *VerifyParams) bool { + return params.Phone == "" && params.Email != "" +} + +func isUsingTokenHash(params *VerifyParams) bool { + return params.TokenHash != "" && params.Token == "" && params.Phone == "" && params.Email == "" +} diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go new file mode 100644 index 000000000..4e45fa21c --- /dev/null +++ b/internal/api/verify_test.go @@ -0,0 +1,1281 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/supabase/auth/internal/api/apierrors" + mail "github.com/supabase/auth/internal/mailer" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" +) + +type VerifyTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestVerify(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &VerifyTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *VerifyTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("12345678", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + + // Create identity + i, err := models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": "test@example.com", + "email_verified": false, + }) + require.NoError(ts.T(), err, "Error creating test identity model") + require.NoError(ts.T(), ts.API.db.Create(i), "Error saving new test identity") +} + +func (ts *VerifyTestSuite) TestVerifyPasswordRecovery() { + // modify config so we don't hit rate limit from requesting recovery twice in 60s + ts.Config.SMTP.MaxFrequency = 60 + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &time.Time{} + require.NoError(ts.T(), ts.API.db.Update(u)) + testEmail := "test@example.com" + + cases := []struct { + desc string + body map[string]interface{} + isPKCE bool + }{ + { + desc: "Implict Flow Recovery", + body: map[string]interface{}{ + "email": testEmail, + }, + isPKCE: false, + }, + { + desc: "PKCE Flow", + body: map[string]interface{}{ + "email": testEmail, + // Code Challenge needs to be at least 43 characters long + "code_challenge": "6b151854-cc15-4e29-8db7-3d3a9f15b3066b151854-cc15-4e29-8db7-3d3a9f15b306", + "code_challenge_method": models.SHA256.String(), + }, + isPKCE: true, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Reset user + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + recoveryToken := u.RecoveryToken + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.RecoveryVerification, recoveryToken) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) + + if c.isPKCE { + rURL, _ := w.Result().Location() + + f, err := url.ParseQuery(rURL.RawQuery) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), f.Get("code")) + } + }) + } +} + +func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { + currentEmail := "test@example.com" + newEmail := "new@example.com" + + // Change from new email to current email and back to new email + cases := []struct { + desc string + body map[string]interface{} + isPKCE bool + currentEmail string + newEmail string + }{ + { + desc: "Implict Flow Email Change", + body: map[string]interface{}{ + "email": newEmail, + }, + isPKCE: false, + currentEmail: currentEmail, + newEmail: newEmail, + }, + { + desc: "PKCE Email Change", + body: map[string]interface{}{ + "email": currentEmail, + // Code Challenge needs to be at least 43 characters long + "code_challenge": "6b151854-cc15-4e29-8db7-3d3a9f15b3066b151854-cc15-4e29-8db7-3d3a9f15b306", + "code_challenge_method": models.SHA256.String(), + }, + isPKCE: true, + currentEmail: newEmail, + newEmail: currentEmail, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // reset user + u.EmailChangeSentAt = nil + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Generate access token for request and a mock session + var token string + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + currentTokenHash := u.EmailChangeTokenCurrent + newTokenHash := u.EmailChangeTokenNew + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.EmailChangeSentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + // Verify new email + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, newTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + urlVal, err := url.Parse(w.Result().Header.Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + var v url.Values + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), singleConfirmation, u.EmailChangeConfirmStatus) + + // Verify old email + reqURL = fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, currentTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + + urlVal, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("access_token")) + ts.Require().NotEmpty(v.Get("expires_in")) + ts.Require().NotEmpty(v.Get("refresh_token")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("code")) + } + + // user's email should've been updated to newEmail + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.newEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.Equal(ts.T(), zeroConfirmation, u.EmailChangeConfirmStatus) + + // Reset confirmation status after each test + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + }) + } +} + +func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { + // verify variant testing not necessary in this test as it's testing + // the ConfirmationSentAt behavior, not the ConfirmationToken behavior + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.ConfirmationToken = "asdf3" + sentTime := time.Now().Add(-48 * time.Hour) + u.ConfirmationSentAt = &sentTime + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + + // Setup request + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.SignupVerification, u.ConfirmationToken) + req := httptest.NewRequest(http.MethodGet, reqURL, nil) + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + + rurl, err := url.Parse(w.Header().Get("Location")) + require.NoError(ts.T(), err, "redirect url parse failed") + + f, err := url.ParseQuery(rurl.Fragment) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), apierrors.ErrorCodeOTPExpired, f.Get("error_code")) + assert.Equal(ts.T(), "Email link is invalid or has expired", f.Get("error_description")) + assert.Equal(ts.T(), "access_denied", f.Get("error")) +} + +func (ts *VerifyTestSuite) TestInvalidOtp() { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, "12345678", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + sentTime := time.Now().Add(-48 * time.Hour) + u.ConfirmationToken = "123456" + u.ConfirmationSentAt = &sentTime + u.PhoneChange = "22222222" + u.PhoneChangeToken = "123456" + u.PhoneChangeSentAt = &sentTime + u.EmailChange = "test@gmail.com" + u.EmailChangeTokenNew = "123456" + u.EmailChangeTokenCurrent = "123456" + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.PhoneChange, u.PhoneChangeToken, models.PhoneChangeToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + + type ResponseBody struct { + Code int `json:"code"` + Msg string `json:"msg"` + } + + expectedResponse := ResponseBody{ + Code: http.StatusForbidden, + Msg: "Token has expired or is invalid", + } + + cases := []struct { + desc string + sentTime time.Time + body map[string]interface{} + expected ResponseBody + }{ + { + desc: "Expired SMS OTP", + sentTime: time.Now().Add(-48 * time.Hour), + body: map[string]interface{}{ + "type": smsVerification, + "token": u.ConfirmationToken, + "phone": u.GetPhone(), + }, + expected: expectedResponse, + }, + { + desc: "Invalid SMS OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": smsVerification, + "token": "invalid_otp", + "phone": u.GetPhone(), + }, + expected: expectedResponse, + }, + { + desc: "Invalid Phone Change OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": phoneChangeVerification, + "token": "invalid_otp", + "phone": u.PhoneChange, + }, + expected: expectedResponse, + }, + { + desc: "Invalid Email OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.SignupVerification, + "token": "invalid_otp", + "email": u.GetEmail(), + }, + expected: expectedResponse, + }, + { + desc: "Invalid Email Change", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token": "invalid_otp", + "email": u.GetEmail(), + }, + expected: expectedResponse, + }, + } + + for _, caseItem := range cases { + c := caseItem + + ts.Run(c.desc, func() { + // update token sent time + sentTime = time.Now() + u.ConfirmationSentAt = &c.sentTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + b, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + var resp ResponseBody + err = json.Unmarshal(b, &resp) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), c.expected.Code, resp.Code) + assert.Equal(ts.T(), c.expected.Msg, resp.Msg) + + }) + } +} + +func (ts *VerifyTestSuite) TestExpiredRecoveryToken() { + // verify variant testing not necessary in this test as it's testing + // the RecoverySentAt behavior, not the RecoveryToken behavior + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoveryToken = "asdf3" + sentTime := time.Now().Add(-48 * time.Hour) + u.RecoverySentAt = &sentTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Setup request + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", "signup", u.RecoveryToken) + req := httptest.NewRequest(http.MethodGet, reqURL, nil) + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusSeeOther, w.Code, w.Body.String()) +} + +func (ts *VerifyTestSuite) TestVerifyPermitedCustomUri() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &time.Time{} + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + redirectURL, _ := url.Parse(ts.Config.URIAllowList[0]) + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "recovery", u.RecoveryToken, redirectURL.String()) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + assert.Equal(ts.T(), redirectURL.Hostname(), rURL.Hostname()) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) +} + +func (ts *VerifyTestSuite) TestVerifyNotPermitedCustomUri() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &time.Time{} + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + fakeredirectURL, _ := url.Parse("http://custom-url.com") + siteURL, _ := url.Parse(ts.Config.SiteURL) + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "recovery", u.RecoveryToken, fakeredirectURL.String()) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + assert.Equal(ts.T(), siteURL.Hostname(), rURL.Hostname()) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) +} + +func (ts *VerifyTestSuite) TestVerifySignupWithRedirectURLContainedPath() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + + testCases := []struct { + desc string + siteURL string + uriAllowList []string + requestredirectURL string + expectedredirectURL string + }{ + { + desc: "same site url and redirect url with path", + siteURL: "http://localhost:3000/#/", + uriAllowList: []string{"http://localhost:3000"}, + requestredirectURL: "http://localhost:3000/#/", + expectedredirectURL: "http://localhost:3000/#/", + }, + { + desc: "different site url and redirect url in allow list", + siteURL: "https://someapp-something.codemagic.app/#/", + uriAllowList: []string{"http://localhost:3000"}, + requestredirectURL: "http://localhost:3000", + expectedredirectURL: "http://localhost:3000", + }, + { + desc: "different site url and redirect url not in allow list", + siteURL: "https://someapp-something.codemagic.app/#/", + uriAllowList: []string{"http://localhost:3000"}, + requestredirectURL: "http://localhost:3000/docs", + expectedredirectURL: "https://someapp-something.codemagic.app/#/", + }, + { + desc: "same wildcard site url and redirect url in allow list", + siteURL: "http://sub.test.dev:3000/#/", + uriAllowList: []string{"http://*.test.dev:3000"}, + requestredirectURL: "http://sub.test.dev:3000/#/", + expectedredirectURL: "http://sub.test.dev:3000/#/", + }, + { + desc: "different wildcard site url and redirect url in allow list", + siteURL: "http://sub.test.dev/#/", + uriAllowList: []string{"http://*.other.dev:3000"}, + requestredirectURL: "http://sub.other.dev:3000", + expectedredirectURL: "http://sub.other.dev:3000", + }, + { + desc: "different wildcard site url and redirect url not in allow list", + siteURL: "http://test.dev:3000/#/", + uriAllowList: []string{"http://*.allowed.dev:3000"}, + requestredirectURL: "http://sub.test.dev:3000/#/", + expectedredirectURL: "http://test.dev:3000/#/", + }, + { + desc: "exact mobile deep link redirect url in allow list", + siteURL: "http://test.dev:3000/#/", + uriAllowList: []string{"twitter://timeline"}, + requestredirectURL: "twitter://timeline", + expectedredirectURL: "twitter://timeline", + }, + // previously the below example was not allowed and with good + // reason, however users do want flexibility in the redirect + // URL after the scheme, which is why the example is now corrected + { + desc: "wildcard mobile deep link redirect url in allow list", + siteURL: "http://test.dev:3000/#/", + uriAllowList: []string{"com.example.app://**"}, + requestredirectURL: "com.example.app://sign-in/v2", + expectedredirectURL: "com.example.app://sign-in/v2", + }, + { + desc: "redirect respects . separator", + siteURL: "http://localhost:3000", + uriAllowList: []string{"http://*.*.dev:3000"}, + requestredirectURL: "http://foo.bar.dev:3000", + expectedredirectURL: "http://foo.bar.dev:3000", + }, + { + desc: "redirect does not respect . separator", + siteURL: "http://localhost:3000", + uriAllowList: []string{"http://*.dev:3000"}, + requestredirectURL: "http://foo.bar.dev:3000", + expectedredirectURL: "http://localhost:3000", + }, + { + desc: "redirect respects / separator in url subdirectory", + siteURL: "http://localhost:3000", + uriAllowList: []string{"http://test.dev:3000/*/*"}, + requestredirectURL: "http://test.dev:3000/bar/foo", + expectedredirectURL: "http://test.dev:3000/bar/foo", + }, + { + desc: "redirect does not respect / separator in url subdirectory", + siteURL: "http://localhost:3000", + uriAllowList: []string{"http://test.dev:3000/*"}, + requestredirectURL: "http://test.dev:3000/bar/foo", + expectedredirectURL: "http://localhost:3000", + }, + } + + for _, tC := range testCases { + ts.Run(tC.desc, func() { + // prepare test data + ts.Config.SiteURL = tC.siteURL + redirectURL := tC.requestredirectURL + ts.Config.URIAllowList = tC.uriAllowList + ts.Config.ApplyDefaults() + + // set verify token to user as it actual do in magic link method + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.ConfirmationToken = "someToken" + sendTime := time.Now().Add(time.Hour) + u.ConfirmationSentAt = &sendTime + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "signup", u.ConfirmationToken, redirectURL) + req := httptest.NewRequest(http.MethodGet, reqURL, nil) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + assert.Contains(ts.T(), rURL.String(), tC.expectedredirectURL) // redirected url starts with per test value + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) + assert.True(ts.T(), u.UserMetaData["email_verified"].(bool)) + assert.True(ts.T(), u.Identities[0].IdentityData["email_verified"].(bool)) + }) + } +} + +func (ts *VerifyTestSuite) TestVerifyPKCEOTP() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now() + u.ConfirmationSentAt = &t + u.RecoverySentAt = &t + u.EmailChangeSentAt = &t + require.NoError(ts.T(), ts.API.db.Update(u)) + + cases := []struct { + desc string + payload *VerifyParams + authenticationMethod models.AuthenticationMethod + }{ + { + desc: "Verify user on signup", + payload: &VerifyParams{ + Type: "signup", + Token: "pkce_confirmation_token", + }, + authenticationMethod: models.EmailSignup, + }, + { + desc: "Verify magiclink", + payload: &VerifyParams{ + Type: "magiclink", + Token: "pkce_recovery_token", + }, + authenticationMethod: models.MagicLink, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + // since the test user is the same, the tokens are being cleared after each successful verification attempt + // so we create them on each run + if c.payload.Type == "signup" { + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), c.payload.Token, models.ConfirmationToken)) + } else if c.payload.Type == "magiclink" { + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), c.payload.Token, models.RecoveryToken)) + } + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload)) + codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" + flowState := models.NewFlowState(c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID) + require.NoError(ts.T(), ts.API.db.Create(flowState)) + + requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token) + req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) + + f, err := url.ParseQuery(rURL.RawQuery) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), f.Get("code")) + }) + } + +} + +func (ts *VerifyTestSuite) TestVerifyBannedUser() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.ConfirmationToken = "confirmation_token" + u.RecoveryToken = "recovery_token" + u.EmailChangeTokenCurrent = "current_email_change_token" + u.EmailChangeTokenNew = "new_email_change_token" + t := time.Now() + u.ConfirmationSentAt = &t + u.RecoverySentAt = &t + u.EmailChangeSentAt = &t + + t = time.Now().Add(24 * time.Hour) + u.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + + cases := []struct { + desc string + payload *VerifyParams + }{ + { + desc: "Verify banned user on signup", + payload: &VerifyParams{ + Type: "signup", + Token: u.ConfirmationToken, + }, + }, + { + desc: "Verify banned user on invite", + payload: &VerifyParams{ + Type: "invite", + Token: u.ConfirmationToken, + }, + }, + { + desc: "Verify banned user on recover", + payload: &VerifyParams{ + Type: "recovery", + Token: u.RecoveryToken, + }, + }, + { + desc: "Verify banned user on magiclink", + payload: &VerifyParams{ + Type: "magiclink", + Token: u.RecoveryToken, + }, + }, + { + desc: "Verify banned user on email change", + payload: &VerifyParams{ + Type: "email_change", + Token: u.EmailChangeTokenCurrent, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload)) + + requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token) + req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + + rurl, err := url.Parse(w.Header().Get("Location")) + require.NoError(ts.T(), err, "redirect url parse failed") + + f, err := url.ParseQuery(rurl.Fragment) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), apierrors.ErrorCodeUserBanned, f.Get("error_code")) + }) + } +} + +func (ts *VerifyTestSuite) TestVerifyValidOtp() { + ts.Config.Mailer.SecureEmailChangeEnabled = true + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.EmailChange = "new@example.com" + u.Phone = "12345678" + u.PhoneChange = "1234567890" + require.NoError(ts.T(), ts.API.db.Update(u)) + + type expected struct { + code int + tokenHash string + } + + cases := []struct { + desc string + sentTime time.Time + body map[string]interface{} + expected + }{ + { + desc: "Valid SMS OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": smsVerification, + "token": "123456", + "phone": u.GetPhone(), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetPhone(), "123456"), + }, + }, + { + desc: "Valid Confirmation OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.SignupVerification, + "token": "123456", + "email": u.GetEmail(), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Signup Token Hash", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.SignupVerification, + "token_hash": crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Recovery OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.RecoveryVerification, + "token": "123456", + "email": u.GetEmail(), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Email OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailOTPVerification, + "token": "123456", + "email": u.GetEmail(), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Email OTP (email casing shouldn't matter)", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailOTPVerification, + "token": "123456", + "email": strings.ToUpper(u.GetEmail()), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Email Change OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token": "123456", + "email": u.EmailChange, + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.EmailChange, "123456"), + }, + }, + { + desc: "Valid Phone Change OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": phoneChangeVerification, + "token": "123456", + "phone": u.PhoneChange, + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.PhoneChange, "123456"), + }, + }, + { + desc: "Valid Email Change Token Hash", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": crypto.GenerateTokenHash(u.EmailChange, "123456"), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.EmailChange, "123456"), + }, + }, + { + desc: "Valid Email Verification Type", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailOTPVerification, + "token_hash": crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + } + + for _, caseItem := range cases { + c := caseItem + ts.Run(c.desc, func() { + // create user + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + u.ConfirmationSentAt = &c.sentTime + u.RecoverySentAt = &c.sentTime + u.EmailChangeSentAt = &c.sentTime + u.PhoneChangeSentAt = &c.sentTime + + u.ConfirmationToken = c.expected.tokenHash + u.RecoveryToken = c.expected.tokenHash + u.EmailChangeTokenNew = c.expected.tokenHash + u.PhoneChangeToken = c.expected.tokenHash + + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.PhoneChangeToken, models.PhoneChangeToken)) + + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), c.expected.code, w.Code) + }) + } +} + +func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { + ts.Config.Mailer.SecureEmailChangeEnabled = true + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.EmailChange = "new@example.com" + require.NoError(ts.T(), ts.API.db.Update(u)) + + currentEmailChangeToken := crypto.GenerateTokenHash(string(u.Email), "123456") + newEmailChangeToken := crypto.GenerateTokenHash(u.EmailChange, "123456") + + cases := []struct { + desc string + firstVerificationBody map[string]interface{} + secondVerificationBody map[string]interface{} + expectedStatus int + }{ + { + desc: "Secure Email Change with Token Hash (Success)", + firstVerificationBody: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": currentEmailChangeToken, + }, + secondVerificationBody: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": newEmailChangeToken, + }, + expectedStatus: http.StatusOK, + }, + { + desc: "Secure Email Change with Token Hash. Reusing a token hash twice should fail", + firstVerificationBody: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": currentEmailChangeToken, + }, + secondVerificationBody: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": currentEmailChangeToken, + }, + expectedStatus: http.StatusForbidden, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Set the corresponding email change tokens + u.EmailChangeTokenCurrent = currentEmailChangeToken + u.EmailChangeTokenNew = newEmailChangeToken + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", currentEmailChangeToken, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", newEmailChangeToken, models.EmailChangeTokenNew)) + + currentTime := time.Now() + u.EmailChangeSentAt = ¤tTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.firstVerificationBody)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.secondVerificationBody)) + + // Setup second request + req = httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup second response recorder + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), c.expectedStatus, w.Code) + }) + } +} + +func (ts *VerifyTestSuite) TestPrepRedirectURL() { + escapedMessage := url.QueryEscape(singleConfirmationAccepted) + cases := []struct { + desc string + message string + rurl string + flowType models.FlowType + expected string + }{ + { + desc: "(PKCE): Redirect URL with additional query params", + message: singleConfirmationAccepted, + rurl: "https://example.com/?first=another&second=other", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("https://example.com/?first=another&message=%s&second=other#message=%s", escapedMessage, escapedMessage), + }, + { + desc: "(PKCE): Query params in redirect url are overriden", + message: singleConfirmationAccepted, + rurl: "https://example.com/?message=Valid+redirect+URL", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("https://example.com/?message=%s#message=%s", escapedMessage, escapedMessage), + }, + { + desc: "(Implicit): plain redirect url", + message: singleConfirmationAccepted, + rurl: "https://example.com/", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("https://example.com/#message=%s", escapedMessage), + }, + { + desc: "(Implicit): query params retained", + message: singleConfirmationAccepted, + rurl: "https://example.com/?first=another", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("https://example.com/?first=another#message=%s", escapedMessage), + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + rurl, err := ts.API.prepRedirectURL(c.message, c.rurl, c.flowType) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expected, rurl) + }) + } +} + +func (ts *VerifyTestSuite) TestPrepErrorRedirectURL() { + const DefaultError = "Invalid redirect URL" + redirectError := fmt.Sprintf("error=invalid_request&error_code=validation_failed&error_description=%s", url.QueryEscape(DefaultError)) + + cases := []struct { + desc string + message string + rurl string + flowType models.FlowType + expected string + }{ + { + desc: "(PKCE): Error in both query params and hash fragment", + message: "Valid redirect URL", + rurl: "https://example.com/", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("https://example.com/?%s#%s", redirectError, redirectError), + }, + { + desc: "(PKCE): Error with conflicting query params in redirect url", + message: DefaultError, + rurl: "https://example.com/?error=Error+to+be+overriden", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("https://example.com/?%s#%s", redirectError, redirectError), + }, + { + desc: "(Implicit): plain redirect url", + message: DefaultError, + rurl: "https://example.com/", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("https://example.com/#%s", redirectError), + }, + { + desc: "(Implicit): query params preserved", + message: DefaultError, + rurl: "https://example.com/?test=param", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("https://example.com/?test=param#%s", redirectError), + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + rurl, err := ts.API.prepErrorRedirectURL(badRequestError(apierrors.ErrorCodeValidationFailed, DefaultError), req, c.rurl, c.flowType) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expected, rurl) + }) + } +} + +func (ts *VerifyTestSuite) TestVerifyValidateParams() { + cases := []struct { + desc string + params *VerifyParams + method string + expected error + }{ + { + desc: "Successful GET Verify", + params: &VerifyParams{ + Type: "signup", + Token: "some-token-hash", + }, + method: http.MethodGet, + expected: nil, + }, + { + desc: "Successful POST Verify (TokenHash)", + params: &VerifyParams{ + Type: "signup", + TokenHash: "some-token-hash", + }, + method: http.MethodPost, + expected: nil, + }, + { + desc: "Successful POST Verify (Token)", + params: &VerifyParams{ + Type: "signup", + Token: "some-token", + Email: "email@example.com", + }, + method: http.MethodPost, + expected: nil, + }, + // unsuccessful validations + { + desc: "Need to send email or phone number with token", + params: &VerifyParams{ + Type: "signup", + Token: "some-token", + }, + method: http.MethodPost, + expected: badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify"), + }, + { + desc: "Cannot send both TokenHash and Token", + params: &VerifyParams{ + Type: "signup", + Token: "some-token", + TokenHash: "some-token-hash", + }, + method: http.MethodPost, + expected: badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires either a token or a token hash"), + }, + { + desc: "No verification type specified", + params: &VerifyParams{ + Token: "some-token", + Email: "email@example.com", + }, + method: http.MethodPost, + expected: badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a verification type"), + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + req := httptest.NewRequest(c.method, "http://localhost", nil) + err := c.params.Validate(req, ts.API) + require.Equal(ts.T(), c.expected, err) + }) + } +} diff --git a/internal/api/web3.go b/internal/api/web3.go new file mode 100644 index 000000000..6dc2c7c81 --- /dev/null +++ b/internal/api/web3.go @@ -0,0 +1,170 @@ +package api + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" + "github.com/supabase/auth/internal/utilities/siws" +) + +type Web3GrantParams struct { + Message string `json:"message,omitempty"` + Signature string `json:"signature,omitempty"` + Chain string `json:"chain,omitempty"` +} + +func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + config := a.config + + if !config.External.Web3Solana.Enabled { + return unprocessableEntityError(apierrors.ErrorCodeWeb3ProviderDisabled, "Web3 provider is disabled") + } + + params := &Web3GrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if params.Chain != "solana" { + return badRequestError(apierrors.ErrorCodeWeb3UnsupportedChain, "Unsupported chain") + } + + return a.web3GrantSolana(ctx, w, r, params) +} + +func (a *API) web3GrantSolana(ctx context.Context, w http.ResponseWriter, r *http.Request, params *Web3GrantParams) error { + config := a.config + db := a.db.WithContext(ctx) + + if len(params.Message) < 64 { + return badRequestError(apierrors.ErrorCodeValidationFailed, "message is too short") + } else if len(params.Message) > 20*1024 { + return badRequestError(apierrors.ErrorCodeValidationFailed, "message must not exceed 20KB") + } + + if len(params.Signature) != 86 && len(params.Signature) != 88 { + return badRequestError(apierrors.ErrorCodeValidationFailed, "signature must be 64 bytes encoded as base64 with or without padding") + } + + base64URLSignature := strings.ReplaceAll(strings.ReplaceAll(strings.TrimRight(params.Signature, "="), "+", "-"), "/", "_") + signatureBytes, err := base64.RawURLEncoding.DecodeString(base64URLSignature) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, "signature does not contain valid base64 characters") + } + + parsedMessage, err := siws.ParseMessage(params.Message) + if err != nil { + return badRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) + } + + if !parsedMessage.VerifySignature(signatureBytes) { + return oauthError("invalid_grant", "Signature does not match address in message") + } + + if parsedMessage.URI.Scheme != "https" { + if parsedMessage.URI.Scheme == "http" && parsedMessage.URI.Hostname() != "localhost" { + return oauthError("invalid_grant", "Signed Solana message is using URI which uses HTTP and hostname is not localhost, only HTTPS is allowed") + } else { + return oauthError("invalid_grant", "Signed Solana message is using URI which does not use HTTPS") + } + } + + if !utilities.IsRedirectURLValid(config, parsedMessage.URI.String()) { + return oauthError("invalid_grant", "Signed Solana message is using URI which is not allowed on this server, message was signed for another app") + } + + if parsedMessage.URI.Host != parsedMessage.Domain || !utilities.IsRedirectURLValid(config, "https://"+parsedMessage.Domain+"/") { + return oauthError("invalid_grant", "Signed Solana message is using a Domain that does not match the one in URI which is not allowed on this server") + } + + now := a.Now() + + if !parsedMessage.NotBefore.IsZero() && now.Before(parsedMessage.NotBefore) { + return oauthError("invalid_grant", "Signed Solana message becomes valid in the future") + } + + if !parsedMessage.ExpirationTime.IsZero() && now.After(parsedMessage.ExpirationTime) { + return oauthError("invalid_grant", "Signed Solana message is expired") + } + + latestExpiryAt := parsedMessage.IssuedAt.Add(config.External.Web3Solana.MaximumValidityDuration) + + if now.After(latestExpiryAt) { + return oauthError("invalid_grant", "Solana message was issued too long ago") + } + + earliestIssuedAt := parsedMessage.IssuedAt.Add(-config.External.Web3Solana.MaximumValidityDuration) + + if now.Before(earliestIssuedAt) { + return oauthError("invalid_grant", "Solana message was issued too far in the future") + } + + providerId := strings.Join([]string{ + "web3", + params.Chain, + parsedMessage.Address, + }, ":") + + userData := provider.UserProvidedData{ + Metadata: &provider.Claims{ + CustomClaims: map[string]interface{}{ + "address": parsedMessage.Address, + "chain": params.Chain, + "network": parsedMessage.ChainID, + "domain": parsedMessage.Domain, + "statement": parsedMessage.Statement, + }, + Subject: providerId, + }, + Emails: []provider.Email{}, + } + + var token *AccessTokenResponse + var grantParams models.GrantParams + grantParams.FillGrantParams(r) + + err = db.Transaction(func(tx *storage.Connection) error { + user, terr := a.createAccountFromExternalIdentity(tx, r, &userData, "web3") + if terr != nil { + return terr + } + + if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider": "web3", + "chain": params.Chain, + "network": parsedMessage.ChainID, + "address": parsedMessage.Address, + "domain": parsedMessage.Domain, + "uri": parsedMessage.URI, + }); terr != nil { + return terr + } + + token, terr = a.issueRefreshToken(r, tx, user, models.Web3, grantParams) + if terr != nil { + return terr + } + + return nil + }) + + if err != nil { + switch err.(type) { + case *storage.CommitWithError: + return err + case *HTTPError: + return err + default: + return oauthError("server_error", "Internal Server Error").WithInternalError(err) + } + } + + return sendJSON(w, http.StatusOK, token) +} diff --git a/internal/api/web3_test.go b/internal/api/web3_test.go new file mode 100644 index 000000000..42d3a00b7 --- /dev/null +++ b/internal/api/web3_test.go @@ -0,0 +1,607 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" +) + +type Web3TestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestWeb3(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &Web3TestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *Web3TestSuite) TestNonSolana() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "blockchain", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + var firstResult struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), apierrors.ErrorCodeWeb3UnsupportedChain, firstResult.ErrorCode) + assert.Equal(ts.T(), "Unsupported chain", firstResult.Message) +} + +func (ts *Web3TestSuite) TestDisabled() { + defer func() { + ts.Config.External.Web3Solana.Enabled = true + }() + + ts.Config.External.Web3Solana.Enabled = false + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + var firstResult struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), apierrors.ErrorCodeWeb3ProviderDisabled, firstResult.ErrorCode) + assert.Equal(ts.T(), "Web3 provider is disabled", firstResult.Message) +} + +func (ts *Web3TestSuite) TestHappyPath_FullMessage() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:09:59Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z\nExpiration Time: 2025-03-29T00:10:00Z\nNot Before: 2025-03-29T00:00:00Z", + "signature": "aiKn+PAoB1OoXxS8H34HrB456YD4sKAVjeTjsxgkaQy3bkdV51WBTmUUE9lBU9kuXr0hTLI+1aTn5TFRbIF8CA==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var firstResult struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.NotEmpty(ts.T(), firstResult.AccessToken) + assert.NotEmpty(ts.T(), firstResult.RefreshToken) +} + +func (ts *Web3TestSuite) TestHappyPath_MinimalMessage() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:09:59Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z", + "signature": "BQxBJ+g2xbMh0LqwYR4ULJ4l7jXFmz33urmp534MS0x7nrGRe2xYdFq41FiGrySX6RipzGqX4kS2vkQmi/+JCg==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var firstResult struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.NotEmpty(ts.T(), firstResult.AccessToken) + assert.NotEmpty(ts.T(), firstResult.RefreshToken) +} + +func (ts *Web3TestSuite) TestValidationRules_URINotHTTPSButIsHTTP() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:00:00Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: http://supaabse.com\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z", + "signature": "zkCDPRAgy3N6KaYJrFgoTGuR+DDn1T6WiC70/m4GSIKMN3rIIDRUHjX/+bDCRyPTq/nC8N9HkMUvoD86gpVKCw==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), firstResult.Error, "invalid_grant") + assert.Equal(ts.T(), firstResult.ErrorDescription, "Signed Solana message is using URI which uses HTTP and hostname is not localhost, only HTTPS is allowed") +} + +func (ts *Web3TestSuite) TestValidationRules_URINotAllowed() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:00:00Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.green wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.green/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z\nExpiration Time: 2025-03-29T00:10:00Z", + "signature": "HlwIlZNfJO2yVqnJfeTz1sEHEbU0pag5yyfWVjmoL6wAXNshOlmQCgbzM8AvdF3/JpeWru2FUsC9cKHchHStDw==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Signed Solana message is using URI which is not allowed on this server, message was signed for another app", firstResult.ErrorDescription) +} + +func (ts *Web3TestSuite) TestValidationRules_URINotHTTPS() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:00:00Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: ftp://supaabse.com\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z", + "signature": "jalHCMtaGNUy5q7BIZRXjdtMJDVDk+ABj/bsIISdbzxc4bjt643llZfjQ3qJJmV1CsnNRgoIyVt8HmGHkIu9CA==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Signed Solana message is using URI which does not use HTTPS", firstResult.ErrorDescription) +} + +func (ts *Web3TestSuite) TestValidationRules_InvalidDomain() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:00:00Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.green wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z", + "signature": "gB9SNz/fxpWir6ZV/oI3pJIYEce5FjSMkbHzDxMH7k6as2jYBVutMU50/UTH59jx3ULZeW3Xt7pDH+9qJCDjAQ==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Signed Solana message is using a Domain that does not match the one in URI which is not allowed on this server", firstResult.ErrorDescription) +} + +func (ts *Web3TestSuite) TestValidationRules_MismatchedDomainAndURIHostname() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:00:00Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.green wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z\nExpiration Time: 2025-03-29T00:10:00Z", + "signature": "KmRa5LqZnwLE5c+PX45QBhuIY2AXWtD8zi3O5lROKJYho8iIt8vZaVo/2utQ5C77LWNL3nI42q/cC8N80hYKAw==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Signed Solana message is using a Domain that does not match the one in URI which is not allowed on this server", firstResult.ErrorDescription) +} + +func (ts *Web3TestSuite) TestValidationRules_ValidatedBeforeNotBefore() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:00:59Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z\nNot Before: 2025-03-29T00:01:00Z", + "signature": "Pe2PpPEK+SIsO3i26SsWNHeFyLKNdcms4Gf7jy8GGR6EvPlWfKNwAtRGMnQa9MvQHgY7QmVOUDSKmYQlvU2sAA==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Signed Solana message becomes valid in the future", firstResult.ErrorDescription) +} + +func (ts *Web3TestSuite) TestValidationRules_Expired() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:10:01Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z\nExpiration Time: 2025-03-29T00:10:00Z\nNot Before: 2025-03-29T00:00:00Z", + "signature": "aiKn+PAoB1OoXxS8H34HrB456YD4sKAVjeTjsxgkaQy3bkdV51WBTmUUE9lBU9kuXr0hTLI+1aTn5TFRbIF8CA==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Signed Solana message is expired", firstResult.ErrorDescription) +} + +func (ts *Web3TestSuite) TestValidationRules_Future() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-28T23:49:59Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z", + "signature": "BQxBJ+g2xbMh0LqwYR4ULJ4l7jXFmz33urmp534MS0x7nrGRe2xYdFq41FiGrySX6RipzGqX4kS2vkQmi/+JCg==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Solana message was issued too far in the future", firstResult.ErrorDescription) +} + +func (ts *Web3TestSuite) TestValidationRules_IssedTooLongAgo() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:00:00Z") + d, _ := time.ParseDuration("10m1s") + + return t.Add(d) + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z\nNot Before: 2025-03-29T00:00:00Z", + "signature": "ds3yyRoevZ0CuyUFOfuAJV/QAA+m302JJjnkOQO3ou5AHPQBNdbwYDj2JzF/5Ox6qyAqN/phU8NnmK8eUtzMDw==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), firstResult.Error, "invalid_grant") + assert.Equal(ts.T(), firstResult.ErrorDescription, "Solana message was issued too long ago") +} + +func (ts *Web3TestSuite) TestValidationRules_InvalidSignature() { + defer func() { + ts.API.overrideTime = nil + }() + + ts.API.overrideTime = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2025-03-29T00:00:00Z") + return t + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": "supabase.com wants you to sign in with your Solana account:\n2EZEiBdw47VHT6SpZSW9VnuSvBe7DxuYHBTxj19gxvv8\n\nStatement\n\nURI: https://supabase.com/\nVersion: 1\nIssued At: 2025-03-29T00:00:00Z\nExpiration Time: 2025-03-29T00:10:00Z\nNot Before: 2025-03-29T00:00:00Z", + "signature": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + + assert.Equal(ts.T(), firstResult.Error, "invalid_grant") + assert.Equal(ts.T(), firstResult.ErrorDescription, "Signature does not match address in message") +} + +func (ts *Web3TestSuite) TestValidationRules_BasicValidation() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": strings.Repeat(" ", 63), + "signature": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx==", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": strings.Repeat(" ", 64), + "signature": strings.Repeat("x", 85), + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": strings.Repeat(" ", 64), + "signature": strings.Repeat("x", 89), + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": strings.Repeat(" ", 20*1024+1), + "signature": strings.Repeat("x", 86), + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": strings.Repeat(" ", 64), + "signature": strings.Repeat("\x00", 86), + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "chain": "solana", + "message": strings.Repeat(" ", 64), + "signature": strings.Repeat("x", 86), + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go new file mode 100644 index 000000000..1961a96f2 --- /dev/null +++ b/internal/conf/configuration.go @@ -0,0 +1,1196 @@ +package conf + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "path/filepath" + "regexp" + "strings" + "text/template" + "time" + + "github.com/gobwas/glob" + "github.com/golang-jwt/jwt/v5" + "github.com/joho/godotenv" + "github.com/kelseyhightower/envconfig" + "github.com/lestrrat-go/jwx/v2/jwk" + "gopkg.in/gomail.v2" +) + +const defaultMinPasswordLength int = 6 +const defaultChallengeExpiryDuration float64 = 300 +const defaultFactorExpiryDuration time.Duration = 300 * time.Second +const defaultFlowStateExpiryDuration time.Duration = 300 * time.Second + +// See: https://www.postgresql.org/docs/7.0/syntax525.htm +var postgresNamesRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]{0,62}$`) + +// See: https://github.com/standard-webhooks/standard-webhooks/blob/main/spec/standard-webhooks.md +// We use 4 * Math.ceil(n/3) to obtain unpadded length in base 64 +// So this 4 * Math.ceil(24/3) = 32 and 4 * Math.ceil(64/3) = 88 for symmetric secrets +// Since Ed25519 key is 32 bytes so we have 4 * Math.ceil(32/3) = 44 +var symmetricSecretFormat = regexp.MustCompile(`^v1,whsec_[A-Za-z0-9+/=]{32,88}`) +var asymmetricSecretFormat = regexp.MustCompile(`^v1a,whpk_[A-Za-z0-9+/=]{44,}:whsk_[A-Za-z0-9+/=]{44,}$`) + +// Time is used to represent timestamps in the configuration, as envconfig has +// trouble parsing empty strings, due to time.Time.UnmarshalText(). +type Time struct { + time.Time +} + +func (t *Time) UnmarshalText(text []byte) error { + trimed := bytes.TrimSpace(text) + + if len(trimed) < 1 { + t.Time = time.Time{} + } else { + if err := t.Time.UnmarshalText(trimed); err != nil { + return err + } + } + + return nil +} + +// OAuthProviderConfiguration holds all config related to external account providers. +type OAuthProviderConfiguration struct { + ClientID []string `json:"client_id" split_words:"true"` + Secret string `json:"secret"` + RedirectURI string `json:"redirect_uri" split_words:"true"` + URL string `json:"url"` + ApiURL string `json:"api_url" split_words:"true"` + Enabled bool `json:"enabled"` + SkipNonceCheck bool `json:"skip_nonce_check" split_words:"true"` +} + +type AnonymousProviderConfiguration struct { + Enabled bool `json:"enabled" default:"false"` +} + +type EmailProviderConfiguration struct { + Enabled bool `json:"enabled" default:"true"` + + AuthorizedAddresses []string `json:"authorized_addresses" split_words:"true"` + + MagicLinkEnabled bool `json:"magic_link_enabled" default:"true" split_words:"true"` +} + +// DBConfiguration holds all the database related configuration. +type DBConfiguration struct { + Driver string `json:"driver" required:"true"` + URL string `json:"url" envconfig:"DATABASE_URL" required:"true"` + Namespace string `json:"namespace" envconfig:"DB_NAMESPACE" default:"auth"` + // MaxPoolSize defaults to 0 (unlimited). + MaxPoolSize int `json:"max_pool_size" split_words:"true"` + MaxIdlePoolSize int `json:"max_idle_pool_size" split_words:"true"` + ConnMaxLifetime time.Duration `json:"conn_max_lifetime,omitempty" split_words:"true"` + ConnMaxIdleTime time.Duration `json:"conn_max_idle_time,omitempty" split_words:"true"` + HealthCheckPeriod time.Duration `json:"health_check_period" split_words:"true"` + MigrationsPath string `json:"migrations_path" split_words:"true" default:"./migrations"` + CleanupEnabled bool `json:"cleanup_enabled" split_words:"true" default:"false"` +} + +func (c *DBConfiguration) Validate() error { + return nil +} + +// JWTConfiguration holds all the JWT related configuration. +type JWTConfiguration struct { + Secret string `json:"secret" required:"true"` + Exp int `json:"exp"` + Aud string `json:"aud"` + AdminGroupName string `json:"admin_group_name" split_words:"true"` + AdminRoles []string `json:"admin_roles" split_words:"true"` + DefaultGroupName string `json:"default_group_name" split_words:"true"` + Issuer string `json:"issuer"` + KeyID string `json:"key_id" split_words:"true"` + Keys JwtKeysDecoder `json:"keys"` + ValidMethods []string `json:"-"` +} + +type MFAFactorTypeConfiguration struct { + EnrollEnabled bool `json:"enroll_enabled" split_words:"true" default:"false"` + VerifyEnabled bool `json:"verify_enabled" split_words:"true" default:"false"` +} + +type TOTPFactorTypeConfiguration struct { + EnrollEnabled bool `json:"enroll_enabled" split_words:"true" default:"true"` + VerifyEnabled bool `json:"verify_enabled" split_words:"true" default:"true"` +} + +type PhoneFactorTypeConfiguration struct { + // Default to false in order to ensure Phone MFA is opt-in + MFAFactorTypeConfiguration + OtpLength int `json:"otp_length" split_words:"true"` + SMSTemplate *template.Template `json:"-"` + MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` + Template string `json:"template"` +} + +// MFAConfiguration holds all the MFA related Configuration +type MFAConfiguration struct { + ChallengeExpiryDuration float64 `json:"challenge_expiry_duration" default:"300" split_words:"true"` + FactorExpiryDuration time.Duration `json:"factor_expiry_duration" default:"300s" split_words:"true"` + RateLimitChallengeAndVerify float64 `split_words:"true" default:"15"` + MaxEnrolledFactors float64 `split_words:"true" default:"10"` + MaxVerifiedFactors int `split_words:"true" default:"10"` + Phone PhoneFactorTypeConfiguration `split_words:"true"` + TOTP TOTPFactorTypeConfiguration `split_words:"true"` + WebAuthn MFAFactorTypeConfiguration `split_words:"true"` +} + +type APIConfiguration struct { + Host string + Port string `envconfig:"PORT" default:"8081"` + Endpoint string + RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"` + ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"` + MaxRequestDuration time.Duration `json:"max_request_duration" split_words:"true" default:"10s"` +} + +func (a *APIConfiguration) Validate() error { + _, err := url.ParseRequestURI(a.ExternalURL) + if err != nil { + return err + } + + return nil +} + +type SessionsConfiguration struct { + Timebox *time.Duration `json:"timebox"` + InactivityTimeout *time.Duration `json:"inactivity_timeout,omitempty" split_words:"true"` + + SinglePerUser bool `json:"single_per_user" split_words:"true"` + Tags []string `json:"tags,omitempty"` +} + +func (c *SessionsConfiguration) Validate() error { + if c.Timebox == nil { + return nil + } + + if *c.Timebox <= time.Duration(0) { + return fmt.Errorf("conf: session timebox duration must be positive when set, was %v", (*c.Timebox).String()) + } + + return nil +} + +type PasswordRequiredCharacters []string + +func (v *PasswordRequiredCharacters) Decode(value string) error { + parts := strings.Split(value, ":") + + for i := 0; i < len(parts)-1; i += 1 { + part := parts[i] + + if part == "" { + continue + } + + // part ended in escape character, so it should be joined with the next one + if part[len(part)-1] == '\\' { + parts[i] = part[0:len(part)-1] + ":" + parts[i+1] + parts[i+1] = "" + continue + } + } + + for _, part := range parts { + if part != "" { + *v = append(*v, part) + } + } + + return nil +} + +// HIBPBloomConfiguration configures a bloom cache for pwned passwords. Use +// this tool to gauge the Items and FalsePositives values: +// https://hur.st/bloomfilter +type HIBPBloomConfiguration struct { + Enabled bool `json:"enabled"` + Items uint `json:"items" default:"100000"` + FalsePositives float64 `json:"false_positives" split_words:"true" default:"0.0000099"` +} + +type HIBPConfiguration struct { + Enabled bool `json:"enabled"` + FailClosed bool `json:"fail_closed" split_words:"true"` + + UserAgent string `json:"user_agent" split_words:"true" default:"https://github.com/supabase/gotrue"` + + Bloom HIBPBloomConfiguration `json:"bloom"` +} + +type PasswordConfiguration struct { + MinLength int `json:"min_length" split_words:"true"` + + RequiredCharacters PasswordRequiredCharacters `json:"required_characters" split_words:"true"` + + HIBP HIBPConfiguration `json:"hibp"` +} + +// GlobalConfiguration holds all the configuration that applies to all instances. +type GlobalConfiguration struct { + API APIConfiguration + DB DBConfiguration + External ProviderConfiguration + Logging LoggingConfig `envconfig:"LOG"` + Profiler ProfilerConfig `envconfig:"PROFILER"` + OperatorToken string `split_words:"true" required:"false"` + Tracing TracingConfig + Metrics MetricsConfig + SMTP SMTPConfiguration + + RateLimitHeader string `split_words:"true"` + RateLimitEmailSent Rate `split_words:"true" default:"30"` + RateLimitSmsSent Rate `split_words:"true" default:"30"` + RateLimitVerify float64 `split_words:"true" default:"30"` + RateLimitTokenRefresh float64 `split_words:"true" default:"150"` + RateLimitSso float64 `split_words:"true" default:"30"` + RateLimitAnonymousUsers float64 `split_words:"true" default:"30"` + RateLimitOtp float64 `split_words:"true" default:"30"` + + SiteURL string `json:"site_url" split_words:"true" required:"true"` + URIAllowList []string `json:"uri_allow_list" split_words:"true"` + URIAllowListMap map[string]glob.Glob + Password PasswordConfiguration `json:"password"` + JWT JWTConfiguration `json:"jwt"` + Mailer MailerConfiguration `json:"mailer"` + Sms SmsProviderConfiguration `json:"sms"` + DisableSignup bool `json:"disable_signup" split_words:"true"` + Hook HookConfiguration `json:"hook" split_words:"true"` + Security SecurityConfiguration `json:"security"` + Sessions SessionsConfiguration `json:"sessions"` + MFA MFAConfiguration `json:"MFA"` + SAML SAMLConfiguration `json:"saml"` + CORS CORSConfiguration `json:"cors"` +} + +type CORSConfiguration struct { + AllowedHeaders []string `json:"allowed_headers" split_words:"true"` +} + +func (c *CORSConfiguration) AllAllowedHeaders(defaults []string) []string { + set := make(map[string]bool) + for _, header := range defaults { + set[header] = true + } + + var result []string + result = append(result, defaults...) + + for _, header := range c.AllowedHeaders { + if !set[header] { + result = append(result, header) + } + + set[header] = true + } + + return result +} + +// EmailContentConfiguration holds the configuration for emails, both subjects and template URLs. +type EmailContentConfiguration struct { + Invite string `json:"invite"` + Confirmation string `json:"confirmation"` + Recovery string `json:"recovery"` + EmailChange string `json:"email_change" split_words:"true"` + MagicLink string `json:"magic_link" split_words:"true"` + Reauthentication string `json:"reauthentication"` +} + +type ProviderConfiguration struct { + AnonymousUsers AnonymousProviderConfiguration `json:"anonymous_users" split_words:"true"` + Apple OAuthProviderConfiguration `json:"apple"` + Azure OAuthProviderConfiguration `json:"azure"` + Bitbucket OAuthProviderConfiguration `json:"bitbucket"` + Discord OAuthProviderConfiguration `json:"discord"` + Facebook OAuthProviderConfiguration `json:"facebook"` + Figma OAuthProviderConfiguration `json:"figma"` + Fly OAuthProviderConfiguration `json:"fly"` + Github OAuthProviderConfiguration `json:"github"` + Gitlab OAuthProviderConfiguration `json:"gitlab"` + Google OAuthProviderConfiguration `json:"google"` + Kakao OAuthProviderConfiguration `json:"kakao"` + Notion OAuthProviderConfiguration `json:"notion"` + Keycloak OAuthProviderConfiguration `json:"keycloak"` + Linkedin OAuthProviderConfiguration `json:"linkedin"` + LinkedinOIDC OAuthProviderConfiguration `json:"linkedin_oidc" envconfig:"LINKEDIN_OIDC"` + Spotify OAuthProviderConfiguration `json:"spotify"` + Slack OAuthProviderConfiguration `json:"slack"` + SlackOIDC OAuthProviderConfiguration `json:"slack_oidc" envconfig:"SLACK_OIDC"` + Twitter OAuthProviderConfiguration `json:"twitter"` + Twitch OAuthProviderConfiguration `json:"twitch"` + VercelMarketplace OAuthProviderConfiguration `json:"vercel_marketplace" split_words:"true"` + WorkOS OAuthProviderConfiguration `json:"workos"` + Email EmailProviderConfiguration `json:"email"` + Phone PhoneProviderConfiguration `json:"phone"` + Zoom OAuthProviderConfiguration `json:"zoom"` + IosBundleId string `json:"ios_bundle_id" split_words:"true"` + RedirectURL string `json:"redirect_url"` + AllowedIdTokenIssuers []string `json:"allowed_id_token_issuers" split_words:"true"` + FlowStateExpiryDuration time.Duration `json:"flow_state_expiry_duration" split_words:"true"` + + Web3Solana SolanaConfiguration `json:"web3_solana" split_words:"true"` +} + +type SolanaConfiguration struct { + Enabled bool `json:"enabled,omitempty" split_words:"true"` + MaximumValidityDuration time.Duration `json:"maximum_validity_duration,omitempty" default:"10m" split_words:"true"` +} + +type SMTPConfiguration struct { + MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` + Host string `json:"host"` + Port int `json:"port,omitempty" default:"587"` + User string `json:"user"` + Pass string `json:"pass,omitempty"` + AdminEmail string `json:"admin_email" split_words:"true"` + SenderName string `json:"sender_name" split_words:"true"` + Headers string `json:"headers"` + LoggingEnabled bool `json:"logging_enabled" split_words:"true" default:"false"` + + fromAddress string `json:"-"` + normalizedHeaders map[string][]string `json:"-"` +} + +func (c *SMTPConfiguration) Validate() error { + headers := make(map[string][]string) + + if c.Headers != "" { + err := json.Unmarshal([]byte(c.Headers), &headers) + if err != nil { + return fmt.Errorf("conf: SMTP headers not a map[string][]string format: %w", err) + } + } + + if len(headers) > 0 { + c.normalizedHeaders = headers + } + + mail := gomail.NewMessage() + + c.fromAddress = mail.FormatAddress(c.AdminEmail, c.SenderName) + + return nil +} + +func (c *SMTPConfiguration) FromAddress() string { + return c.fromAddress +} + +func (c *SMTPConfiguration) NormalizedHeaders() map[string][]string { + return c.normalizedHeaders +} + +type MailerConfiguration struct { + Autoconfirm bool `json:"autoconfirm"` + AllowUnverifiedEmailSignIns bool `json:"allow_unverified_email_sign_ins" split_words:"true" default:"false"` + + Subjects EmailContentConfiguration `json:"subjects"` + Templates EmailContentConfiguration `json:"templates"` + URLPaths EmailContentConfiguration `json:"url_paths"` + + SecureEmailChangeEnabled bool `json:"secure_email_change_enabled" split_words:"true" default:"true"` + + OtpExp uint `json:"otp_exp" split_words:"true"` + OtpLength int `json:"otp_length" split_words:"true"` + + ExternalHosts []string `json:"external_hosts" split_words:"true"` + + // EXPERIMENTAL: May be removed in a future release. + EmailValidationExtended bool `json:"email_validation_extended" split_words:"true" default:"false"` + EmailValidationServiceURL string `json:"email_validation_service_url" split_words:"true"` + EmailValidationServiceHeaders string `json:"email_validation_service_headers" split_words:"true"` + EmailValidationBlockedMX string `json:"email_validation_blocked_mx" split_words:"true"` + + serviceHeaders map[string][]string `json:"-"` + blockedMXRecords map[string]bool `json:"-"` +} + +func (c *MailerConfiguration) Validate() error { + headers := make(map[string][]string) + + if c.EmailValidationServiceHeaders != "" { + err := json.Unmarshal([]byte(c.EmailValidationServiceHeaders), &headers) + if err != nil { + return fmt.Errorf("conf: mailer validation headers not a map[string][]string format: %w", err) + } + } + + if len(headers) > 0 { + c.serviceHeaders = headers + } + + // EmailValidationBlockedMX is a JSON array in the config string for brevity. + var blockedMXRecords map[string]bool + if c.EmailValidationBlockedMX != "" { + var blockedMXArray []string + err := json.Unmarshal([]byte(c.EmailValidationBlockedMX), &blockedMXArray) + if err != nil { + return fmt.Errorf("conf: email_validation_blocked_mx is not a valid JSON array: %w", err) + } + blockedMXRecords = make(map[string]bool, len(blockedMXArray)*2) + for _, record := range blockedMXArray { + blockedMXRecords[record] = true + blockedMXRecords[record+"."] = true + } + } + + c.blockedMXRecords = blockedMXRecords + + return nil +} + +func (c *MailerConfiguration) GetEmailValidationServiceHeaders() map[string][]string { + return c.serviceHeaders +} + +func (c *MailerConfiguration) GetEmailValidationBlockedMXRecords() map[string]bool { + return c.blockedMXRecords +} + +type PhoneProviderConfiguration struct { + Enabled bool `json:"enabled" default:"false"` +} + +type SmsProviderConfiguration struct { + Autoconfirm bool `json:"autoconfirm"` + MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` + OtpExp uint `json:"otp_exp" split_words:"true"` + OtpLength int `json:"otp_length" split_words:"true"` + Provider string `json:"provider"` + Template string `json:"template"` + TestOTP map[string]string `json:"test_otp" split_words:"true"` + TestOTPValidUntil Time `json:"test_otp_valid_until" split_words:"true"` + SMSTemplate *template.Template `json:"-"` + + Twilio TwilioProviderConfiguration `json:"twilio"` + TwilioVerify TwilioVerifyProviderConfiguration `json:"twilio_verify" split_words:"true"` + Messagebird MessagebirdProviderConfiguration `json:"messagebird"` + Textlocal TextlocalProviderConfiguration `json:"textlocal"` + Vonage VonageProviderConfiguration `json:"vonage"` +} + +func (c *SmsProviderConfiguration) GetTestOTP(phone string, now time.Time) (string, bool) { + if c.TestOTP != nil && (c.TestOTPValidUntil.Time.IsZero() || now.Before(c.TestOTPValidUntil.Time)) { + testOTP, ok := c.TestOTP[phone] + return testOTP, ok + } + + return "", false +} + +type TwilioProviderConfiguration struct { + AccountSid string `json:"account_sid" split_words:"true"` + AuthToken string `json:"auth_token" split_words:"true"` + MessageServiceSid string `json:"message_service_sid" split_words:"true"` + ContentSid string `json:"content_sid" split_words:"true"` +} + +type TwilioVerifyProviderConfiguration struct { + AccountSid string `json:"account_sid" split_words:"true"` + AuthToken string `json:"auth_token" split_words:"true"` + MessageServiceSid string `json:"message_service_sid" split_words:"true"` +} + +type MessagebirdProviderConfiguration struct { + AccessKey string `json:"access_key" split_words:"true"` + Originator string `json:"originator" split_words:"true"` +} + +type TextlocalProviderConfiguration struct { + ApiKey string `json:"api_key" split_words:"true"` + Sender string `json:"sender" split_words:"true"` +} + +type VonageProviderConfiguration struct { + ApiKey string `json:"api_key" split_words:"true"` + ApiSecret string `json:"api_secret" split_words:"true"` + From string `json:"from" split_words:"true"` +} + +type CaptchaConfiguration struct { + Enabled bool `json:"enabled" default:"false"` + Provider string `json:"provider" default:"hcaptcha"` + Secret string `json:"provider_secret"` +} + +func (c *CaptchaConfiguration) Validate() error { + if !c.Enabled { + return nil + } + + if c.Provider != "hcaptcha" && c.Provider != "turnstile" { + return fmt.Errorf("unsupported captcha provider: %s", c.Provider) + } + + c.Secret = strings.TrimSpace(c.Secret) + + if c.Secret == "" { + return errors.New("captcha provider secret is empty") + } + + return nil +} + +// DatabaseEncryptionConfiguration configures Auth to encrypt certain columns. +// Once Encrypt is set to true, data will start getting encrypted with the +// provided encryption key. Setting it to false just stops encryption from +// going on further, but DecryptionKeys would have to contain the same key so +// the encrypted data remains accessible. +type DatabaseEncryptionConfiguration struct { + Encrypt bool `json:"encrypt"` + + EncryptionKeyID string `json:"encryption_key_id" split_words:"true"` + EncryptionKey string `json:"-" split_words:"true"` + + DecryptionKeys map[string]string `json:"-" split_words:"true"` +} + +func (c *DatabaseEncryptionConfiguration) Validate() error { + if c.Encrypt { + if c.EncryptionKeyID == "" { + return errors.New("conf: encryption key ID must be specified") + } + + decodedKey, err := base64.RawURLEncoding.DecodeString(c.EncryptionKey) + if err != nil { + return err + } + + if len(decodedKey) != 256/8 { + return errors.New("conf: encryption key is not 256 bits") + } + + if c.DecryptionKeys == nil || c.DecryptionKeys[c.EncryptionKeyID] == "" { + return errors.New("conf: encryption key must also be present in decryption keys") + } + } + + for id, key := range c.DecryptionKeys { + decodedKey, err := base64.RawURLEncoding.DecodeString(key) + if err != nil { + return err + } + + if len(decodedKey) != 256/8 { + return fmt.Errorf("conf: decryption key with ID %q must be 256 bits", id) + } + } + + return nil +} + +type SecurityConfiguration struct { + Captcha CaptchaConfiguration `json:"captcha"` + RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"` + RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"` + UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"` + ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"` + + DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"` +} + +func (c *SecurityConfiguration) Validate() error { + if err := c.Captcha.Validate(); err != nil { + return err + } + + if err := c.DBEncryption.Validate(); err != nil { + return err + } + + return nil +} + +func loadEnvironment(filename string) error { + var err error + if filename != "" { + err = godotenv.Overload(filename) + } else { + err = godotenv.Load() + // handle if .env file does not exist, this is OK + if os.IsNotExist(err) { + return nil + } + } + return err +} + +// Moving away from the existing HookConfig so we can get a fresh start. +type HookConfiguration struct { + MFAVerificationAttempt ExtensibilityPointConfiguration `json:"mfa_verification_attempt" split_words:"true"` + PasswordVerificationAttempt ExtensibilityPointConfiguration `json:"password_verification_attempt" split_words:"true"` + CustomAccessToken ExtensibilityPointConfiguration `json:"custom_access_token" split_words:"true"` + SendEmail ExtensibilityPointConfiguration `json:"send_email" split_words:"true"` + SendSMS ExtensibilityPointConfiguration `json:"send_sms" split_words:"true"` +} + +type HTTPHookSecrets []string + +func (h *HTTPHookSecrets) Decode(value string) error { + parts := strings.Split(value, "|") + for _, part := range parts { + if part != "" { + *h = append(*h, part) + } + } + + return nil +} + +type ExtensibilityPointConfiguration struct { + URI string `json:"uri"` + Enabled bool `json:"enabled"` + // For internal use together with Postgres Hook. Not publicly exposed. + HookName string `json:"-"` + // We use | as a separator for keys and : as a separator for keys within a keypair. For instance: v1,whsec_test|v1a,whpk_myother:v1a,whsk_testkey|v1,whsec_secret3 + HTTPHookSecrets HTTPHookSecrets `json:"secrets" envconfig:"secrets"` +} + +func (h *HookConfiguration) Validate() error { + points := []ExtensibilityPointConfiguration{ + h.MFAVerificationAttempt, + h.PasswordVerificationAttempt, + h.CustomAccessToken, + h.SendSMS, + h.SendEmail, + } + for _, point := range points { + if err := point.ValidateExtensibilityPoint(); err != nil { + return err + } + } + return nil +} + +func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { + if e.URI == "" { + return nil + } + u, err := url.Parse(e.URI) + if err != nil { + return err + } + switch strings.ToLower(u.Scheme) { + case "pg-functions": + return validatePostgresPath(u) + case "http": + hostname := u.Hostname() + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == "host.docker.internal" { + return validateHTTPHookSecrets(e.HTTPHookSecrets) + } + return fmt.Errorf("only localhost, 127.0.0.1, and ::1 are supported with http") + case "https": + return validateHTTPHookSecrets(e.HTTPHookSecrets) + default: + return fmt.Errorf("only postgres hooks and HTTPS functions are supported at the moment") + } +} + +func validatePostgresPath(u *url.URL) error { + pathParts := strings.Split(u.Path, "/") + if len(pathParts) < 3 { + return fmt.Errorf("URI path does not contain enough parts") + } + + schema := pathParts[1] + table := pathParts[2] + // Validate schema and table names + if !postgresNamesRegexp.MatchString(schema) { + return fmt.Errorf("invalid schema name: %s", schema) + } + if !postgresNamesRegexp.MatchString(table) { + return fmt.Errorf("invalid table name: %s", table) + } + return nil +} + +func isValidSecretFormat(secret string) bool { + return symmetricSecretFormat.MatchString(secret) || asymmetricSecretFormat.MatchString(secret) +} + +func validateHTTPHookSecrets(secrets []string) error { + for _, secret := range secrets { + if !isValidSecretFormat(secret) { + return fmt.Errorf("invalid secret format") + } + } + return nil +} + +func (e *ExtensibilityPointConfiguration) PopulateExtensibilityPoint() error { + u, err := url.Parse(e.URI) + if err != nil { + return err + } + if u.Scheme == "pg-functions" { + pathParts := strings.Split(u.Path, "/") + e.HookName = fmt.Sprintf("%q.%q", pathParts[1], pathParts[2]) + } + return nil +} + +// LoadFile calls godotenv.Load() when the given filename is empty ignoring any +// errors loading, otherwise it calls godotenv.Overload(filename). +// +// godotenv.Load: preserves env, ".env" path is optional +// godotenv.Overload: overrides env, "filename" path must exist +func LoadFile(filename string) error { + var err error + if filename != "" { + err = godotenv.Overload(filename) + } else { + err = godotenv.Load() + // handle if .env file does not exist, this is OK + if os.IsNotExist(err) { + return nil + } + } + return err +} + +// LoadDirectory does nothing when configDir is empty, otherwise it will attempt +// to load a list of configuration files located in configDir by using ReadDir +// to obtain a sorted list of files containing a .env suffix. +// +// When the list is empty it will do nothing, otherwise it passes the file list +// to godotenv.Overload to pull them into the current environment. +func LoadDirectory(configDir string) error { + if configDir == "" { + return nil + } + + // Returns entries sorted by filename + ents, err := os.ReadDir(configDir) + if err != nil { + // We mimic the behavior of LoadGlobal here, if an explicit path is + // provided we return an error. + return err + } + + var paths []string + for _, ent := range ents { + if ent.IsDir() { + continue // ignore directories + } + + // We only read files ending in .env + name := ent.Name() + if !strings.HasSuffix(name, ".env") { + continue + } + + // ent.Name() does not include the watch dir. + paths = append(paths, filepath.Join(configDir, name)) + } + + // If at least one path was found we load the configuration files in the + // directory. We don't call override without config files because it will + // override the env vars previously set with a ".env", if one exists. + return loadDirectoryPaths(paths...) +} + +func loadDirectoryPaths(p ...string) error { + // If at least one path was found we load the configuration files in the + // directory. We don't call override without config files because it will + // override the env vars previously set with a ".env", if one exists. + if len(p) > 0 { + if err := godotenv.Overload(p...); err != nil { + return err + } + } + return nil +} + +// LoadGlobalFromEnv will return a new *GlobalConfiguration value from the +// currently configured environment. +func LoadGlobalFromEnv() (*GlobalConfiguration, error) { + config := new(GlobalConfiguration) + if err := loadGlobal(config); err != nil { + return nil, err + } + return config, nil +} + +func LoadGlobal(filename string) (*GlobalConfiguration, error) { + if err := loadEnvironment(filename); err != nil { + return nil, err + } + + config := new(GlobalConfiguration) + if err := loadGlobal(config); err != nil { + return nil, err + } + return config, nil +} + +func loadGlobal(config *GlobalConfiguration) error { + // although the package is called "auth" it used to be called "gotrue" + // so environment configs will remain to be called "GOTRUE" + if err := envconfig.Process("gotrue", config); err != nil { + return err + } + + if err := config.ApplyDefaults(); err != nil { + return err + } + + if err := config.Validate(); err != nil { + return err + } + return populateGlobal(config) +} + +func populateGlobal(config *GlobalConfiguration) error { + if config.Hook.PasswordVerificationAttempt.Enabled { + if err := config.Hook.PasswordVerificationAttempt.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.Hook.SendSMS.Enabled { + if err := config.Hook.SendSMS.PopulateExtensibilityPoint(); err != nil { + return err + } + } + if config.Hook.SendEmail.Enabled { + if err := config.Hook.SendEmail.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.Hook.MFAVerificationAttempt.Enabled { + if err := config.Hook.MFAVerificationAttempt.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.Hook.CustomAccessToken.Enabled { + if err := config.Hook.CustomAccessToken.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.SAML.Enabled { + if err := config.SAML.PopulateFields(config.API.ExternalURL); err != nil { + return err + } + } else { + config.SAML.PrivateKey = "" + } + + if config.Sms.Provider != "" { + SMSTemplate := config.Sms.Template + if SMSTemplate == "" { + SMSTemplate = "Your code is {{ .Code }}" + } + template, err := template.New("").Parse(SMSTemplate) + if err != nil { + return err + } + config.Sms.SMSTemplate = template + } + + if config.MFA.Phone.EnrollEnabled || config.MFA.Phone.VerifyEnabled { + smsTemplate := config.MFA.Phone.Template + if smsTemplate == "" { + smsTemplate = "Your code is {{ .Code }}" + } + template, err := template.New("").Parse(smsTemplate) + if err != nil { + return err + } + config.MFA.Phone.SMSTemplate = template + } + + return nil +} + +// ApplyDefaults sets defaults for a GlobalConfiguration +func (config *GlobalConfiguration) ApplyDefaults() error { + if config.JWT.AdminGroupName == "" { + config.JWT.AdminGroupName = "admin" + } + + if len(config.JWT.AdminRoles) == 0 { + config.JWT.AdminRoles = []string{"service_role", "supabase_admin"} + } + + if config.JWT.Exp == 0 { + config.JWT.Exp = 3600 + } + + if len(config.JWT.Keys) == 0 { + // transform the secret into a JWK for consistency + if err := config.applyDefaultsJWT([]byte(config.JWT.Secret)); err != nil { + return err + } + } + + if config.JWT.ValidMethods == nil { + config.JWT.ValidMethods = []string{} + for _, key := range config.JWT.Keys { + alg := GetSigningAlg(key.PublicKey) + config.JWT.ValidMethods = append(config.JWT.ValidMethods, alg.Alg()) + } + + } + + if config.Mailer.Autoconfirm && config.Mailer.AllowUnverifiedEmailSignIns { + return errors.New("cannot enable both GOTRUE_MAILER_AUTOCONFIRM and GOTRUE_MAILER_ALLOW_UNVERIFIED_EMAIL_SIGN_INS") + } + + if config.Mailer.URLPaths.Invite == "" { + config.Mailer.URLPaths.Invite = "/verify" + } + + if config.Mailer.URLPaths.Confirmation == "" { + config.Mailer.URLPaths.Confirmation = "/verify" + } + + if config.Mailer.URLPaths.Recovery == "" { + config.Mailer.URLPaths.Recovery = "/verify" + } + + if config.Mailer.URLPaths.EmailChange == "" { + config.Mailer.URLPaths.EmailChange = "/verify" + } + + if config.Mailer.OtpExp == 0 { + config.Mailer.OtpExp = 86400 // 1 day + } + + if config.Mailer.OtpLength == 0 || config.Mailer.OtpLength < 6 || config.Mailer.OtpLength > 10 { + // 6-digit otp by default + config.Mailer.OtpLength = 6 + } + + if config.SMTP.MaxFrequency == 0 { + config.SMTP.MaxFrequency = 1 * time.Minute + } + + if config.Sms.MaxFrequency == 0 { + config.Sms.MaxFrequency = 1 * time.Minute + } + + if config.Sms.OtpExp == 0 { + config.Sms.OtpExp = 60 + } + + if config.Sms.OtpLength == 0 || config.Sms.OtpLength < 6 || config.Sms.OtpLength > 10 { + // 6-digit otp by default + config.Sms.OtpLength = 6 + } + + if config.Sms.TestOTP != nil { + formatTestOtps := make(map[string]string) + for phone, otp := range config.Sms.TestOTP { + phone = strings.ReplaceAll(strings.TrimPrefix(phone, "+"), " ", "") + formatTestOtps[phone] = otp + } + config.Sms.TestOTP = formatTestOtps + } + + if len(config.Sms.Template) == 0 { + config.Sms.Template = "" + } + + if config.URIAllowList == nil { + config.URIAllowList = []string{} + } + + if config.URIAllowList != nil { + config.URIAllowListMap = make(map[string]glob.Glob) + for _, uri := range config.URIAllowList { + g := glob.MustCompile(uri, '.', '/') + config.URIAllowListMap[uri] = g + } + } + + if config.Password.MinLength < defaultMinPasswordLength { + config.Password.MinLength = defaultMinPasswordLength + } + + if config.MFA.ChallengeExpiryDuration < defaultChallengeExpiryDuration { + config.MFA.ChallengeExpiryDuration = defaultChallengeExpiryDuration + } + + if config.MFA.FactorExpiryDuration < defaultFactorExpiryDuration { + config.MFA.FactorExpiryDuration = defaultFactorExpiryDuration + } + + if config.MFA.Phone.MaxFrequency == 0 { + config.MFA.Phone.MaxFrequency = 1 * time.Minute + } + + if config.MFA.Phone.OtpLength < 6 || config.MFA.Phone.OtpLength > 10 { + // 6-digit otp by default + config.MFA.Phone.OtpLength = 6 + } + + if config.External.FlowStateExpiryDuration < defaultFlowStateExpiryDuration { + config.External.FlowStateExpiryDuration = defaultFlowStateExpiryDuration + } + + if len(config.External.AllowedIdTokenIssuers) == 0 { + config.External.AllowedIdTokenIssuers = append(config.External.AllowedIdTokenIssuers, "https://appleid.apple.com", "https://accounts.google.com") + } + + return nil +} +func (config *GlobalConfiguration) applyDefaultsJWT(secret []byte) error { + // transform the secret into a JWK for consistency + privKey, err := jwk.FromRaw(secret) + if err != nil { + return err + } + return config.applyDefaultsJWTPrivateKey(privKey) +} + +func (config *GlobalConfiguration) applyDefaultsJWTPrivateKey(privKey jwk.Key) error { + if config.JWT.KeyID != "" { + if err := privKey.Set(jwk.KeyIDKey, config.JWT.KeyID); err != nil { + return err + } + } + if privKey.Algorithm().String() == "" { + if err := privKey.Set(jwk.AlgorithmKey, jwt.SigningMethodHS256.Name); err != nil { + return err + } + } + if err := privKey.Set(jwk.KeyUsageKey, "sig"); err != nil { + return err + } + if len(privKey.KeyOps()) == 0 { + if err := privKey.Set(jwk.KeyOpsKey, jwk.KeyOperationList{jwk.KeyOpSign, jwk.KeyOpVerify}); err != nil { + return err + } + } + pubKey, err := privKey.PublicKey() + if err != nil { + return err + } + config.JWT.Keys = make(JwtKeysDecoder) + config.JWT.Keys[config.JWT.KeyID] = JwkInfo{ + PublicKey: pubKey, + PrivateKey: privKey, + } + return nil +} + +// Validate validates all of configuration. +func (c *GlobalConfiguration) Validate() error { + validatables := []interface { + Validate() error + }{ + &c.API, + &c.DB, + &c.Tracing, + &c.Metrics, + &c.SMTP, + &c.Mailer, + &c.SAML, + &c.Security, + &c.Sessions, + &c.Hook, + &c.JWT.Keys, + } + + for _, validatable := range validatables { + if err := validatable.Validate(); err != nil { + return err + } + } + + return nil +} + +func (o *OAuthProviderConfiguration) ValidateOAuth() error { + if !o.Enabled { + return errors.New("provider is not enabled") + } + if len(o.ClientID) == 0 { + return errors.New("missing OAuth client ID") + } + if o.Secret == "" { + return errors.New("missing OAuth secret") + } + if o.RedirectURI == "" { + return errors.New("missing redirect URI") + } + return nil +} + +func (t *TwilioProviderConfiguration) Validate() error { + if t.AccountSid == "" { + return errors.New("missing Twilio account SID") + } + if t.AuthToken == "" { + return errors.New("missing Twilio auth token") + } + if t.MessageServiceSid == "" { + return errors.New("missing Twilio message service SID or Twilio phone number") + } + return nil +} + +func (t *TwilioVerifyProviderConfiguration) Validate() error { + if t.AccountSid == "" { + return errors.New("missing Twilio account SID") + } + if t.AuthToken == "" { + return errors.New("missing Twilio auth token") + } + if t.MessageServiceSid == "" { + return errors.New("missing Twilio message service SID or Twilio phone number") + } + return nil +} + +func (t *MessagebirdProviderConfiguration) Validate() error { + if t.AccessKey == "" { + return errors.New("missing Messagebird access key") + } + if t.Originator == "" { + return errors.New("missing Messagebird originator") + } + return nil +} + +func (t *TextlocalProviderConfiguration) Validate() error { + if t.ApiKey == "" { + return errors.New("missing Textlocal API key") + } + if t.Sender == "" { + return errors.New("missing Textlocal sender") + } + return nil +} + +func (t *VonageProviderConfiguration) Validate() error { + if t.ApiKey == "" { + return errors.New("missing Vonage API key") + } + if t.ApiSecret == "" { + return errors.New("missing Vonage API secret") + } + if t.From == "" { + return errors.New("missing Vonage 'from' parameter") + } + return nil +} + +func (t *SmsProviderConfiguration) IsTwilioVerifyProvider() bool { + return t.Provider == "twilio_verify" +} diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go new file mode 100644 index 000000000..8c7dcf880 --- /dev/null +++ b/internal/conf/configuration_test.go @@ -0,0 +1,1166 @@ +package conf + +import ( + "encoding/base64" + "errors" + "os" + "sort" + "strings" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + defer os.Clearenv() + os.Exit(m.Run()) +} + +func TestGlobal(t *testing.T) { + os.Setenv("GOTRUE_SITE_URL", "http://localhost:8080") + os.Setenv("GOTRUE_DB_DRIVER", "postgres") + os.Setenv("GOTRUE_DB_DATABASE_URL", "fake") + os.Setenv("GOTRUE_OPERATOR_TOKEN", "token") + os.Setenv("GOTRUE_API_REQUEST_ID_HEADER", "X-Request-ID") + os.Setenv("GOTRUE_JWT_SECRET", "secret") + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + os.Setenv("GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_URI", "pg-functions://postgres/auth/count_failed_attempts") + os.Setenv("GOTRUE_HOOK_SEND_SMS_SECRETS", "v1,whsec_aWxpa2VzdXBhYmFzZXZlcnltdWNoYW5kaWhvcGV5b3Vkb3Rvbw==") + os.Setenv("GOTRUE_SMTP_HEADERS", `{"X-PM-Metadata-project-ref":["project_ref"],"X-SES-Message-Tags":["ses:feedback-id-a=project_ref,ses:feedback-id-b=$messageType"]}`) + os.Setenv("GOTRUE_MAILER_EMAIL_VALIDATION_SERVICE_HEADERS", `{"apikey":["test"]}`) + os.Setenv("GOTRUE_SMTP_LOGGING_ENABLED", "true") + gc, err := LoadGlobal("") + require.NoError(t, err) + assert.Equal(t, true, gc.SMTP.LoggingEnabled) + assert.Equal(t, "project_ref", gc.SMTP.NormalizedHeaders()["X-PM-Metadata-project-ref"][0]) + require.NotNil(t, gc) + assert.Equal(t, "X-Request-ID", gc.API.RequestIDHeader) + assert.Equal(t, "pg-functions://postgres/auth/count_failed_attempts", gc.Hook.MFAVerificationAttempt.URI) + + { + os.Setenv("GOTRUE_RATE_LIMIT_EMAIL_SENT", "0/1h") + + gc, err := LoadGlobal("") + require.NoError(t, err) + assert.Equal(t, float64(0), gc.RateLimitEmailSent.Events) + assert.Equal(t, time.Hour, gc.RateLimitEmailSent.OverTime) + } + + { + os.Setenv("GOTRUE_RATE_LIMIT_EMAIL_SENT", "10/1h") + + gc, err := LoadGlobal("") + require.NoError(t, err) + assert.Equal(t, float64(10), gc.RateLimitEmailSent.Events) + assert.Equal(t, time.Hour, gc.RateLimitEmailSent.OverTime) + } + + { + hdrs := gc.Mailer.GetEmailValidationServiceHeaders() + assert.Equal(t, 1, len(hdrs["apikey"])) + assert.Equal(t, "test", hdrs["apikey"][0]) + } + + { + cfg, err := LoadGlobalFromEnv() + require.NoError(t, err) + require.NotNil(t, cfg) + } + + { + cfg, err := LoadGlobal("") + require.NoError(t, err) + require.NotNil(t, cfg) + } + + { + cfg, err := LoadGlobal("__invalid__") + require.Error(t, err) + require.Nil(t, cfg) + } + + { + os.Setenv("GOTRUE_MAILER_AUTOCONFIRM", "TRUE") + os.Setenv("GOTRUE_MAILER_ALLOW_UNVERIFIED_EMAIL_SIGN_INS", "TRUE") + cfg, err := LoadGlobal("") + require.Error(t, err) + require.Nil(t, cfg) + os.Setenv("GOTRUE_MAILER_AUTOCONFIRM", "FALSE") + os.Setenv("GOTRUE_MAILER_ALLOW_UNVERIFIED_EMAIL_SIGN_INS", "FALSE") + } + + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + err := loadGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + cfg.Hook = HookConfiguration{ + PasswordVerificationAttempt: ExtensibilityPointConfiguration{ + Enabled: true, + URI: "\n", + }, + } + + err := populateGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + cfg.Hook = HookConfiguration{ + SendSMS: ExtensibilityPointConfiguration{ + Enabled: true, + URI: "\n", + }, + } + + err := populateGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + cfg.Hook = HookConfiguration{ + SendEmail: ExtensibilityPointConfiguration{ + Enabled: true, + URI: "\n", + }, + } + + err := populateGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + cfg.Hook = HookConfiguration{ + MFAVerificationAttempt: ExtensibilityPointConfiguration{ + Enabled: true, + URI: "\n", + }, + } + + err := populateGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + cfg.Hook = HookConfiguration{ + CustomAccessToken: ExtensibilityPointConfiguration{ + Enabled: true, + URI: "\n", + }, + } + + err := populateGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + cfg.SAML = SAMLConfiguration{ + Enabled: true, + } + + err := populateGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + + { + cfg := new(GlobalConfiguration) + cfg.Sms.Provider = "invalid" + + err := populateGlobal(cfg) + require.NoError(t, err) + } + + { + cfg := new(GlobalConfiguration) + cfg.Sms.Provider = "invalid" + cfg.Sms.Template = "{{{{{{{{{}}}}}}}}}" + + err := populateGlobal(cfg) + require.Error(t, err) + } + + { + cfg := new(GlobalConfiguration) + cfg.MFA.Phone.EnrollEnabled = true + cfg.MFA.Phone.Template = "{{{{{{{{{}}}}}}}}}" + + err := populateGlobal(cfg) + require.Error(t, err) + } + + { + cfg := new(GlobalConfiguration) + cfg.MFA.Phone.EnrollEnabled = true + + err := populateGlobal(cfg) + require.NoError(t, err) + } +} + +func TestPasswordRequiredCharactersDecode(t *testing.T) { + examples := []struct { + Value string + Result []string + }{ + { + Value: "a:b:c", + Result: []string{ + "a", + "b", + "c", + }, + }, + { + Value: "a\\:b:c", + Result: []string{ + "a:b", + "c", + }, + }, + { + Value: "a:b\\:c", + Result: []string{ + "a", + "b:c", + }, + }, + { + Value: "\\:a:b:c", + Result: []string{ + ":a", + "b", + "c", + }, + }, + { + Value: "a:b:c\\:", + Result: []string{ + "a", + "b", + "c:", + }, + }, + { + Value: "::\\::", + Result: []string{ + ":", + }, + }, + { + Value: "", + Result: nil, + }, + { + Value: " ", + Result: []string{ + " ", + }, + }, + } + + for i, example := range examples { + var into PasswordRequiredCharacters + require.NoError(t, into.Decode(example.Value), "Example %d failed with error", i) + + require.Equal(t, []string(into), example.Result, "Example %d got unexpected result", i) + } +} + +func TestHTTPHookSecretsDecode(t *testing.T) { + examples := []struct { + Value string + Result []string + }{ + { + Value: "v1,whsec_secret1|v1a,whpk_secrets:whsk_secret2|v1,whsec_secret3", + Result: []string{"v1,whsec_secret1", "v1a,whpk_secrets:whsk_secret2", "v1,whsec_secret3"}, + }, + { + Value: "v1,whsec_singlesecret", + Result: []string{"v1,whsec_singlesecret"}, + }, + { + Value: " ", + Result: []string{" "}, + }, + { + Value: "", + Result: nil, + }, + { + Value: "|a|b|c", + Result: []string{ + "a", + "b", + "c", + }, + }, + { + Value: "||||", + Result: nil, + }, + { + Value: "::", + Result: []string{"::"}, + }, + { + Value: "secret1::secret3", + Result: []string{"secret1::secret3"}, + }, + } + + for i, example := range examples { + var into HTTPHookSecrets + + require.NoError(t, into.Decode(example.Value), "Example %d failed with error", i) + require.Equal(t, []string(into), example.Result, "Example %d got unexpected result", i) + } +} + +func TestValidateExtensibilityPointURI(t *testing.T) { + cases := []struct { + desc string + uri string + expectError bool + }{ + // Positive test cases + {desc: "Valid HTTPS URI", uri: "https://asdfgggqqwwerty.website.co/functions/v1/custom-sms-sender", expectError: false}, + {desc: "Valid HTTPS URI", uri: "HTTPS://www.asdfgggqqwwerty.website.co/functions/v1/custom-sms-sender", expectError: false}, + {desc: "Valid Postgres URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectError: false}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/user_management/add_user", expectError: false}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/MySpeCial/FUNCTION_THAT_YELLS_AT_YOU", expectError: false}, + {desc: "Valid HTTP URI", uri: "http://localhost/functions/v1/custom-sms-sender", expectError: false}, + + // Negative test cases + {desc: "Invalid HTTP URI", uri: "http://asdfgggg.website.co/functions/v1/custom-sms-sender", expectError: true}, + {desc: "Invalid HTTPS URI (HTTP)", uri: "http://asdfgggqqwwerty.supabase.co/functions/v1/custom-sms-sender", expectError: true}, + {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectError: true}, + {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectError: true}, + {desc: "Insufficient Path Parts", uri: "pg-functions://postgres/auth", expectError: true}, + } + + for _, tc := range cases { + ep := ExtensibilityPointConfiguration{URI: tc.uri} + err := ep.ValidateExtensibilityPoint() + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + } +} + +func TestValidateExtensibilityPointSecrets(t *testing.T) { + validHTTPSURI := "https://asdfgggqqwwerty.website.co/functions/v1/custom-sms-sender" + cases := []struct { + desc string + secret []string + expectError bool + }{ + // Positive test cases + {desc: "Valid Symmetric Secret", secret: []string{"v1,whsec_NDYzODhlNTY0ZGI1OWZjYTU2NjMwN2FhYzM3YzBkMWQ0NzVjNWRkNTJmZDU0MGNhYTAzMjVjNjQzMzE3Mjk2Zg====="}, expectError: false}, + {desc: "Valid Asymmetric Secret", secret: []string{"v1a,whpk_NDYzODhlNTY0ZGI1OWZjYTU2NjMwN2FhYzM3YzBkMWQ0NzVjNWRkNTJmZDU0MGNhYTAzMjVjNjQzMzE3Mjk2Zg==:whsk_abc889a6b1160015025064f108a48d6aba1c7c95fa8e304b4d225e8ae0121511"}, expectError: false}, + {desc: "Valid Mix of Symmetric and asymmetric Secret", secret: []string{"v1,whsec_2b49264c90fd15db3bb0e05f4e1547b9c183eb06d585be8a", "v1a,whpk_46388e564db59fca566307aac37c0d1d475c5dd52fd540caa0325c643317296f:whsk_YWJjODg5YTZiMTE2MDAxNTAyNTA2NGYxMDhhNDhkNmFiYTFjN2M5NWZhOGUzMDRiNGQyMjVlOGFlMDEyMTUxMSI="}, expectError: false}, + + // Negative test cases + {desc: "Invalid Asymmetric Secret", secret: []string{"v1a,john:jill", "jill"}, expectError: true}, + {desc: "Invalid Symmetric Secret", secret: []string{"tommy"}, expectError: true}, + } + for _, tc := range cases { + ep := ExtensibilityPointConfiguration{URI: validHTTPSURI, HTTPHookSecrets: tc.secret} + err := ep.ValidateExtensibilityPoint() + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + } + +} + +func TestTime(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2025-01-01T10:00:00.00Z") + + cases := []struct { + txt string + exp time.Time + err string + }{ + + // valid + { + txt: now.Format(time.RFC3339), + exp: now, + }, + + // trimmed + { + txt: " " + now.Format(time.RFC3339) + "\n \r", + exp: now, + }, + + // len < 1 + { + txt: "", + exp: time.Time{}, + }, + + // invalid time format + { + txt: "invalid", + exp: time.Time{}, + err: `"invalid" as "2006-01-02T15:04:05Z07:00":` + + ` cannot parse "invalid" as "2006"`, + }, + } + + for idx, tc := range cases { + t.Logf("test #%v - exp err %v with time %v from UnmarshalText(%q)", + idx, tc.err, tc.exp, tc.txt) + + var v Time + err := v.UnmarshalText([]byte(tc.txt)) + if tc.err != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.err) + continue + } + require.NoError(t, err) + require.Equal(t, tc.exp, v.Time) + } +} + +func TestValidate(t *testing.T) { + type testCase struct { + val interface{ Validate() error } + check func(t *testing.T, v any) + err string + } + cases := []testCase{ + { + val: &APIConfiguration{ExternalURL: "http://localhost"}, + }, + { + val: &APIConfiguration{ExternalURL: "invalid"}, + err: `parse "invalid": invalid URI for request`, + }, + + { + val: &APIConfiguration{ExternalURL: "invalid"}, + err: `parse "invalid": invalid URI for request`, + }, + + { + val: &SessionsConfiguration{Timebox: nil}, + }, + { + val: &SessionsConfiguration{Timebox: new(time.Duration)}, + err: `conf: session timebox duration must` + + ` be positive when set, was 0`, + }, + { + val: &SessionsConfiguration{Timebox: toPtr(time.Duration(-1))}, + err: `conf: session timebox duration must` + + ` be positive when set, was -1`, + }, + { + val: &SessionsConfiguration{Timebox: toPtr(time.Duration(1))}, + }, + + { + val: &SMTPConfiguration{}, + }, + { + val: &SMTPConfiguration{ + AdminEmail: "test@example.com", + SenderName: "Test", + }, + check: func(t *testing.T, v any) { + got := (v.(*SMTPConfiguration)).FromAddress() + require.Equal(t, `"Test" `, got) + }, + }, + { + val: &SMTPConfiguration{Headers: "invalid"}, + err: `conf: SMTP headers not a map[string][]string format:` + + ` invalid character 'i' looking for beginning of value`, + }, + + { + val: &MailerConfiguration{}, + }, + { + val: &MailerConfiguration{EmailValidationServiceHeaders: "invalid"}, + err: `conf: mailer validation headers not a map[string][]string format:` + + ` invalid character 'i' looking for beginning of value`, + }, + + { + val: &CaptchaConfiguration{Enabled: false}, + }, + { + val: &CaptchaConfiguration{Enabled: true}, + err: "unsupported captcha provider:", + }, + { + val: &CaptchaConfiguration{ + Enabled: true, + Provider: "hcaptcha", + Secret: "", + }, + err: "captcha provider secret is empty", + }, + { + val: &CaptchaConfiguration{ + Enabled: true, + Provider: "hcaptcha", + Secret: "abc", + }, + }, + + { + val: &DatabaseEncryptionConfiguration{Encrypt: false}, + }, + { + val: &DatabaseEncryptionConfiguration{ + Encrypt: true, + EncryptionKeyID: "", + }, + err: "conf: encryption key ID must be specified", + }, + { + val: &DatabaseEncryptionConfiguration{ + Encrypt: true, + EncryptionKeyID: "keyid", + EncryptionKey: "|", + }, + err: "illegal base64 data at input byte 0", + }, + { + val: &DatabaseEncryptionConfiguration{ + Encrypt: true, + EncryptionKeyID: "keyid", + EncryptionKey: "aaaaaaa", + }, + err: "conf: encryption key is not 256 bits", + }, + { + val: &DatabaseEncryptionConfiguration{ + Encrypt: true, + EncryptionKeyID: "keyid", + EncryptionKey: base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + err: "conf: encryption key must also be present in decryption keys", + }, + { + val: &DatabaseEncryptionConfiguration{ + Encrypt: true, + DecryptionKeys: map[string]string{ + "keyid": "|", + }, + EncryptionKeyID: "keyid", + EncryptionKey: base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + err: "illegal base64 data at input byte 0", + }, + { + val: &DatabaseEncryptionConfiguration{ + Encrypt: true, + DecryptionKeys: map[string]string{ + "keyid": "aaa", + }, + EncryptionKeyID: "keyid", + EncryptionKey: base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + err: `conf: decryption key with ID "keyid" must be 256 bits`, + }, + { + val: &DatabaseEncryptionConfiguration{ + Encrypt: true, + DecryptionKeys: map[string]string{ + "keyid": base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + EncryptionKeyID: "keyid", + EncryptionKey: base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + }, + + { + val: &SecurityConfiguration{ + Captcha: CaptchaConfiguration{ + Enabled: true, + Provider: "hcaptcha", + Secret: "abc", + }, + DBEncryption: DatabaseEncryptionConfiguration{ + Encrypt: true, + DecryptionKeys: map[string]string{ + "keyid": base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + EncryptionKeyID: "keyid", + EncryptionKey: base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + }, + }, + { + val: &SecurityConfiguration{ + Captcha: CaptchaConfiguration{ + Enabled: true, + }, + DBEncryption: DatabaseEncryptionConfiguration{ + Encrypt: true, + DecryptionKeys: map[string]string{ + "keyid": base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + EncryptionKeyID: "keyid", + EncryptionKey: base64.RawURLEncoding.EncodeToString( + []byte(strings.Repeat("a", 32)), + ), + }, + }, + err: `unsupported captcha provider:`, + }, + { + val: &SecurityConfiguration{ + Captcha: CaptchaConfiguration{ + Enabled: true, + Provider: "hcaptcha", + Secret: "abc", + }, + DBEncryption: DatabaseEncryptionConfiguration{ + Encrypt: true, + }, + }, + err: `conf: encryption key ID must be specified`, + }, + + { + val: &TwilioProviderConfiguration{}, + err: `missing Twilio account SID`, + }, + { + val: &TwilioProviderConfiguration{ + AccountSid: "a", + }, + err: `missing Twilio auth token`, + }, + { + val: &TwilioProviderConfiguration{ + AccountSid: "a", + AuthToken: "b", + }, + err: `missing Twilio message service SID or Twilio phone number`, + }, + { + val: &TwilioProviderConfiguration{ + AccountSid: "a", + AuthToken: "b", + MessageServiceSid: "c", + }, + }, + + { + val: &GlobalConfiguration{}, + err: `parse "": empty url`, + }, + + { + val: &TwilioVerifyProviderConfiguration{}, + err: `missing Twilio account SID`, + }, + { + val: &TwilioVerifyProviderConfiguration{ + AccountSid: "a", + }, + err: `missing Twilio auth token`, + }, + { + val: &TwilioVerifyProviderConfiguration{ + AccountSid: "a", + AuthToken: "b", + }, + err: `missing Twilio message service SID or Twilio phone number`, + }, + { + val: &TwilioVerifyProviderConfiguration{ + AccountSid: "a", + AuthToken: "b", + MessageServiceSid: "c", + }, + }, + + { + val: &MessagebirdProviderConfiguration{}, + err: `missing Messagebird access key`, + }, + { + val: &MessagebirdProviderConfiguration{ + AccessKey: "a", + }, + err: `missing Messagebird originator`, + }, + { + val: &MessagebirdProviderConfiguration{ + AccessKey: "a", + Originator: "b", + }, + }, + + { + val: &TextlocalProviderConfiguration{}, + err: `missing Textlocal API key`, + }, + { + val: &TextlocalProviderConfiguration{ + ApiKey: "a", + }, + err: `missing Textlocal sender`, + }, + { + val: &TextlocalProviderConfiguration{ + ApiKey: "a", + Sender: "b", + }, + }, + + { + val: &VonageProviderConfiguration{}, + err: `missing Vonage API key`, + }, + { + val: &VonageProviderConfiguration{ + ApiKey: "a", + }, + err: `missing Vonage API secret`, + }, + { + val: &VonageProviderConfiguration{ + ApiKey: "a", + ApiSecret: "b", + }, + err: `missing Vonage 'from' parameter`, + }, + { + val: &VonageProviderConfiguration{ + ApiKey: "a", + ApiSecret: "b", + From: "c", + }, + }, + + { + val: &HookConfiguration{ + MFAVerificationAttempt: ExtensibilityPointConfiguration{ + URI: "|", + }, + }, + err: `only postgres hooks and HTTPS functions are supported at the moment`, + }, + { + val: &HookConfiguration{ + MFAVerificationAttempt: ExtensibilityPointConfiguration{ + URI: "http://localhost/foo", + }, + }, + }, + { + val: &HookConfiguration{ + MFAVerificationAttempt: ExtensibilityPointConfiguration{ + URI: "\n", + }, + }, + err: `net/url: invalid control character in URL`, + check: func(t *testing.T, v any) { + hcfg := (v.(*HookConfiguration)) + err := hcfg.MFAVerificationAttempt.PopulateExtensibilityPoint() + require.Error(t, err) + }, + }, + { + val: &HookConfiguration{ + MFAVerificationAttempt: ExtensibilityPointConfiguration{ + URI: "http://localhost/foo", + }, + }, + check: func(t *testing.T, v any) { + hcfg := (v.(*HookConfiguration)) + err := hcfg.MFAVerificationAttempt.PopulateExtensibilityPoint() + require.NoError(t, err) + }, + }, + { + val: &HookConfiguration{ + MFAVerificationAttempt: ExtensibilityPointConfiguration{ + URI: "pg-functions://foo/bar/baz", + }, + }, + check: func(t *testing.T, v any) { + hcfg := (v.(*HookConfiguration)) + err := hcfg.MFAVerificationAttempt.PopulateExtensibilityPoint() + require.NoError(t, err) + }, + }, + } + + for idx, tc := range cases { + t.Logf("test #%v - exp err %v from %T.Validate() (%#[3]v)", + idx, tc.err, tc.val) + + err := tc.val.Validate() + if tc.check != nil { + tc.check(t, tc.val) + } + + if tc.err != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.err) + continue + } + require.NoError(t, err) + } +} + +func TestMethods(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2025-01-01T10:00:00.00Z") + + { + val := &CORSConfiguration{ + AllowedHeaders: []string{ + "X-Test1", + }, + } + + got := val.AllAllowedHeaders(nil) + sort.Strings(got) + require.Equal(t, []string{"X-Test1"}, got) + + got = val.AllAllowedHeaders([]string{"X-Test2"}) + sort.Strings(got) + require.Equal(t, []string{"X-Test1", "X-Test2"}, got) + + val.AllowedHeaders = nil + sort.Strings(got) + got = val.AllAllowedHeaders([]string{"X-Test2"}) + require.Equal(t, []string{"X-Test2"}, got) + + val.AllowedHeaders = nil + got = val.AllAllowedHeaders(nil) + require.Equal(t, ([]string)(nil), got) + } + + { + val := &SmsProviderConfiguration{} + ok := val.IsTwilioVerifyProvider() + require.False(t, ok) + + val.Provider = "twilio_verify" + ok = val.IsTwilioVerifyProvider() + require.True(t, ok) + + // invalid otp (TestOTP map == nil) + got, ok := val.GetTestOTP("13338888", now) + require.False(t, ok) + require.Equal(t, "", got) + + // valid + val.TestOTP = map[string]string{"13334444": "123456"} + got, ok = val.GetTestOTP("13334444", now) + require.True(t, ok) + require.Equal(t, "123456", got) + + // invalid otp (not in non-nil TestOTP map) + got, ok = val.GetTestOTP("13338888", now) + require.False(t, ok) + require.Equal(t, "", got) + + // valid otp with non-zero time + val.TestOTPValidUntil = Time{Time: now.Add(time.Second)} + got, ok = val.GetTestOTP("13334444", now) + require.True(t, ok) + require.Equal(t, "123456", got) + + // invalid otp (expired) + val.TestOTPValidUntil = Time{Time: now.Add(time.Second)} + got, ok = val.GetTestOTP("13338888", now.Add(time.Second*2)) + require.False(t, ok) + require.Equal(t, "", got) + } + + { + val := &OAuthProviderConfiguration{} + + err := val.ValidateOAuth() + require.Error(t, err) + require.Contains(t, err.Error(), "provider is not enabled") + + val.Enabled = true + err = val.ValidateOAuth() + require.Error(t, err) + require.Contains(t, err.Error(), "missing OAuth client ID") + + val.ClientID = []string{"a"} + err = val.ValidateOAuth() + require.Error(t, err) + require.Contains(t, err.Error(), "missing OAuth secret") + + val.Secret = "a" + err = val.ValidateOAuth() + require.Error(t, err) + require.Contains(t, err.Error(), "missing redirect URI") + + val.RedirectURI = "a" + err = val.ValidateOAuth() + require.NoError(t, err) + } + + { + val := &GlobalConfiguration{} + err := val.ApplyDefaults() + require.Error(t, err) + require.Contains(t, err.Error(), + `failed to initialize *jwk.symmetricKey from []uint8:`+ + ` non-empty []byte key required`) + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + }, + } + err := val.ApplyDefaults() + require.NoError(t, err) + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + KeyID: "a", + }, + } + + sentinel := errors.New("sentinel") + key := helpToMockKey(t, sentinel) + key.n = 0 + + err := val.applyDefaultsJWTPrivateKey(key) + require.Error(t, err) + require.Contains(t, err.Error(), "sentinel") + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + KeyID: "a", + }, + } + + sentinel := errors.New("sentinel") + key := helpToMockKey(t, sentinel) + key.n = -1 + key.alg = jwa.SignatureAlgorithm("") + + err := val.applyDefaultsJWTPrivateKey(key) + require.Error(t, err) + require.Contains(t, err.Error(), "sentinel") + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + KeyID: "a", + }, + } + + sentinel := errors.New("sentinel") + key := helpToMockKey(t, sentinel) + key.n = -2 + key.alg = jwa.SignatureAlgorithm("") + + err := val.applyDefaultsJWTPrivateKey(key) + require.Error(t, err) + require.Contains(t, err.Error(), "sentinel") + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + KeyID: "a", + }, + } + + sentinel := errors.New("sentinel") + key := helpToMockKey(t, sentinel) + key.n = -3 + key.ops = jwk.KeyOperationList{} + key.alg = jwa.SignatureAlgorithm("") + + err := val.applyDefaultsJWTPrivateKey(key) + require.Error(t, err) + require.Contains(t, err.Error(), "sentinel") + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + KeyID: "a", + }, + } + + sentinel := errors.New("sentinel") + key := helpToMockKey(t, sentinel) + key.n = -4 + key.alg = jwa.SignatureAlgorithm("") + + err := val.applyDefaultsJWTPrivateKey(key) + require.Error(t, err) + require.Contains(t, err.Error(), "sentinel") + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + }, + Mailer: MailerConfiguration{ + Autoconfirm: true, + AllowUnverifiedEmailSignIns: true, + }, + } + err := val.ApplyDefaults() + require.Error(t, err) + require.Contains(t, err.Error(), `cannot enable both `+ + `GOTRUE_MAILER_AUTOCONFIRM and `+ + `GOTRUE_MAILER_ALLOW_UNVERIFIED_EMAIL_SIGN_INS`) + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + }, + Sms: SmsProviderConfiguration{ + TestOTP: map[string]string{"13334444": "123456"}, + }, + } + err := val.ApplyDefaults() + require.NoError(t, err) + } + { + val := &GlobalConfiguration{ + JWT: JWTConfiguration{ + Secret: "a", + }, + URIAllowList: []string{ + "http://localhost/*/**", + }, + } + err := val.ApplyDefaults() + require.NoError(t, err) + } +} + +func TestLoading(t *testing.T) { + defer os.Clearenv() + + { + os.Clearenv() + err := LoadFile("abc") + require.Error(t, err) + } + + { + os.Clearenv() + err := LoadFile("") + require.NoError(t, err) + } + + { + os.Clearenv() + err := loadEnvironment("abc") + require.Error(t, err) + } + + { + os.Clearenv() + err := loadEnvironment("") + require.NoError(t, err) + } + + { + os.Clearenv() + err := LoadDirectory("") + require.NoError(t, err) + } + + { + os.Clearenv() + err := LoadDirectory("__invalid__") + require.Error(t, err) + require.Contains(t, err.Error(), + `open __invalid__: no such file or directory`) + } + + { + os.Clearenv() + err := LoadDirectory("../reloader/testdata") + require.NoError(t, err) + } + + { + os.Clearenv() + err := loadDirectoryPaths("__invalid__") + require.Error(t, err) + } + + { + os.Clearenv() + cfg, err := LoadGlobalFromEnv() + require.Error(t, err) + require.Nil(t, cfg) + } +} + +func toPtr[T any](v T) *T { + return &(&([1]T{T(v)}))[0] +} diff --git a/internal/conf/jwk.go b/internal/conf/jwk.go new file mode 100644 index 000000000..0bdf94d9a --- /dev/null +++ b/internal/conf/jwk.go @@ -0,0 +1,172 @@ +package conf + +import ( + "encoding/json" + "fmt" + + "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jwk" +) + +type JwtKeysDecoder map[string]JwkInfo + +type JwkInfo struct { + PublicKey jwk.Key `json:"public_key"` + PrivateKey jwk.Key `json:"private_key"` +} + +// Decode implements the Decoder interface +func (j *JwtKeysDecoder) Decode(value string) error { + data := make([]json.RawMessage, 0) + if err := json.Unmarshal([]byte(value), &data); err != nil { + return err + } + + config := JwtKeysDecoder{} + for _, key := range data { + if err := j.decodeKey(config, key); err != nil { + return err + } + } + + *j = config + return nil +} + +func (j *JwtKeysDecoder) decodeKey(config JwtKeysDecoder, key []byte) error { + privJwk, err := jwk.ParseKey(key) + if err != nil { + return err + } + return j.decodePrivateKey(config, privJwk) +} + +func (j *JwtKeysDecoder) decodePrivateKey( + config JwtKeysDecoder, + privJwk jwk.Key, +) error { + pubJwk, err := jwk.PublicKeyOf(privJwk) + if err != nil { + return err + } + return j.decodePublicKey(config, privJwk, pubJwk) +} + +func (j *JwtKeysDecoder) decodePublicKey( + config JwtKeysDecoder, + privJwk jwk.Key, + pubJwk jwk.Key, +) error { + // all public keys should have the the use claim set to 'sig + if err := pubJwk.Set(jwk.KeyUsageKey, "sig"); err != nil { + return err + } + + // all public keys should only have 'verify' set as the key_ops + if err := pubJwk.Set(jwk.KeyOpsKey, jwk.KeyOperationList{jwk.KeyOpVerify}); err != nil { + return err + } + + config[pubJwk.KeyID()] = JwkInfo{ + PublicKey: pubJwk, + PrivateKey: privJwk, + } + return nil +} + +func (j *JwtKeysDecoder) Validate() error { + // Validate performs _minimal_ checks if the data stored in the key are valid. + // By minimal, we mean that it does not check if the key is valid for use in + // cryptographic operations. For example, it does not check if an RSA key's + // `e` field is a valid exponent, or if the `n` field is a valid modulus. + // Instead, it checks for things such as the _presence_ of some required fields, + // or if certain keys' values are of particular length. + // + // Note that depending on the underlying key type, use of this method requires + // that multiple fields in the key are properly populated. For example, an EC + // key's "x", "y" fields cannot be validated unless the "crv" field is populated first. + signingKeys := []jwk.Key{} + for _, key := range *j { + if err := key.PrivateKey.Validate(); err != nil { + return err + } + // symmetric keys don't have public keys + if key.PublicKey != nil { + if err := key.PublicKey.Validate(); err != nil { + return err + } + } + + for _, op := range key.PrivateKey.KeyOps() { + if op == jwk.KeyOpSign { + signingKeys = append(signingKeys, key.PrivateKey) + break + } + } + } + + switch { + case len(signingKeys) == 0: + return fmt.Errorf("no signing key detected") + case len(signingKeys) > 1: + return fmt.Errorf("multiple signing keys detected, only 1 signing key is supported") + } + + return nil +} + +func GetSigningJwk(config *JWTConfiguration) (jwk.Key, error) { + for _, key := range config.Keys { + for _, op := range key.PrivateKey.KeyOps() { + // the private JWK with key_ops "sign" should be used as the signing key + if op == jwk.KeyOpSign { + return key.PrivateKey, nil + } + } + } + return nil, fmt.Errorf("no signing key found") +} + +func GetSigningKey(k jwk.Key) (any, error) { + var key any + if err := k.Raw(&key); err != nil { + return nil, err + } + return key, nil +} + +func GetSigningAlg(k jwk.Key) jwt.SigningMethod { + if k == nil { + return jwt.SigningMethodHS256 + } + + switch (k).Algorithm().String() { + case "RS256": + return jwt.SigningMethodRS256 + case "RS512": + return jwt.SigningMethodRS512 + case "ES256": + return jwt.SigningMethodES256 + case "ES512": + return jwt.SigningMethodES512 + case "EdDSA": + return jwt.SigningMethodEdDSA + } + + // return HS256 to preserve existing behaviour + return jwt.SigningMethodHS256 +} + +func FindPublicKeyByKid(kid string, config *JWTConfiguration) (any, error) { + if k, ok := config.Keys[kid]; ok { + key, err := GetSigningKey(k.PublicKey) + if err != nil { + return nil, err + } + return key, nil + } + if kid == config.KeyID { + return []byte(config.Secret), nil + } + return nil, fmt.Errorf("invalid kid: %s", kid) +} diff --git a/internal/conf/jwk_test.go b/internal/conf/jwk_test.go new file mode 100644 index 000000000..a086e5578 --- /dev/null +++ b/internal/conf/jwk_test.go @@ -0,0 +1,449 @@ +package conf + +import ( + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/stretchr/testify/require" +) + +func TestJwtKeys(t *testing.T) { + + // JwtKeysDecoder - Decode - unmarshal error + { + dec := make(JwtKeysDecoder) + jwt := "{" + err := dec.Decode(jwt) + require.Error(t, err) + } + + // JwtKeysDecoder - Decode - error calling decodeKey on index 0 + { + dec := make(JwtKeysDecoder) + jwt := "[{}]" + err := dec.Decode(jwt) + require.Error(t, err) + } + + // JwtKeysDecoder - Decode - ParseKey error + { + dec := make(JwtKeysDecoder) + dst := make(JwtKeysDecoder) + jwt := "{}" + err := dec.decodeKey(dst, []byte(jwt)) + require.Error(t, err) + } + + // JwtKeysDecoder - Decode - PublicKeyOf error + { + dec := make(JwtKeysDecoder) + dst := make(JwtKeysDecoder) + err := dec.decodePrivateKey(dst, nil) + require.Error(t, err) + } + + // JwtKeysDecoder - Decode - Set jwt.KeyUsageKey error + { + sentinel := errors.New("sentinel") + key := helpToMockKey(t, sentinel) + key.n = 0 + + dec := make(JwtKeysDecoder) + dst := make(JwtKeysDecoder) + err := dec.decodePublicKey(dst, key, key) + require.Error(t, err) + } + + // JwtKeysDecoder - Decode - Set jwt.KeyOpsKey error + { + sentinel := errors.New("sentinel") + key := helpToMockKey(t, sentinel) + key.n = -1 + + dec := make(JwtKeysDecoder) + dst := make(JwtKeysDecoder) + err := dec.decodePublicKey(dst, key, key) + require.Error(t, err) + } + + // JwtKeysDecoder - Validate - signing keys == 0 + { + dec := make(JwtKeysDecoder) + err := dec.Validate() + require.Error(t, err) + } + + // JwtKeysDecoder - Validate - signing keys > 1 + { + dec := make(JwtKeysDecoder) + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["3"]) + require.NoError(t, err) + + sigKey1, err := jwk.ParseKey(jwt) + require.NoError(t, err) + + pubJwk, err := jwk.PublicKeyOf(sigKey1) + require.NoError(t, err) + + dec["sig1"] = JwkInfo{ + PublicKey: pubJwk, + PrivateKey: sigKey1, + } + dec["sig2"] = JwkInfo{ + PublicKey: pubJwk, + PrivateKey: sigKey1, + } + + err = dec.Validate() + require.Error(t, err) + } + + // JwtKeysDecoder - Validate - PrivateKey.Validate() error + { + dec := make(JwtKeysDecoder) + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["3"]) + require.NoError(t, err) + + privKey, err := jwk.ParseKey(jwt) + require.NoError(t, err) + + sentinel := errors.New("sentinel") + key := &mockKey{Key: privKey, err: sentinel, n: 0} + + dec["sig1"] = JwkInfo{ + PublicKey: key, + PrivateKey: key, + } + + err = dec.Validate() + require.Error(t, err) + } + + // JwtKeysDecoder - Validate - PublicKey.Validate() error + { + dec := make(JwtKeysDecoder) + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["3"]) + require.NoError(t, err) + + privKey, err := jwk.ParseKey(jwt) + require.NoError(t, err) + + sentinel := errors.New("sentinel") + key := &mockKey{Key: privKey, err: sentinel, n: -1} + + dec["sig1"] = JwkInfo{ + PublicKey: key, + PrivateKey: key, + } + + err = dec.Validate() + require.Error(t, err) + } + + // GetSigningJwk - valid + { + dec := make(JwtKeysDecoder) + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["3"]) + require.NoError(t, err) + + sigKey1, err := jwk.ParseKey(jwt) + require.NoError(t, err) + + pubJwk, err := jwk.PublicKeyOf(sigKey1) + require.NoError(t, err) + + info := JwkInfo{ + PublicKey: pubJwk, + PrivateKey: sigKey1, + } + dec["sig1"] = info + + jwtConfig := &JWTConfiguration{ + Keys: dec, + } + got, err := GetSigningJwk(jwtConfig) + require.NoError(t, err) + require.Equal(t, sigKey1, got) + } + + // GetSigningJwk - not found + { + dec := make(JwtKeysDecoder) + jwtConfig := &JWTConfiguration{ + Keys: dec, + } + got, err := GetSigningJwk(jwtConfig) + require.Nil(t, got) + require.Error(t, err) + require.Equal(t, "no signing key found", err.Error()) + } + + // GetSigningKey - valid + { + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["3"]) + require.NoError(t, err) + + sigKey1, err := jwk.ParseKey(jwt) + require.NoError(t, err) + + got, err := GetSigningKey(sigKey1) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, fmt.Sprintf("%T", got), "*ecdsa.PrivateKey") + } + + // GetSigningKey - not found + { + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["4"]) + require.NoError(t, err) + + privKey, err := jwk.ParseKey(jwt) + require.NoError(t, err) + + sentinel := errors.New("sentinel") + key := &mockKey{Key: privKey, err: sentinel, n: 0} + + got, err := GetSigningKey(key) + require.Nil(t, got) + require.Error(t, err) + require.Equal(t, sentinel, err) + } + + // FindPublicKeyByKid - valid + { + dec := make(JwtKeysDecoder) + jwtConfig := &JWTConfiguration{ + Keys: dec, + KeyID: "abc", + Secret: "sentinel", + } + got, err := FindPublicKeyByKid("abc", jwtConfig) + require.NoError(t, err) + require.Equal(t, []byte("sentinel"), got) + } + + // FindPublicKeyByKid - not found + { + dec := make(JwtKeysDecoder) + jwtConfig := &JWTConfiguration{ + Keys: dec, + } + got, err := FindPublicKeyByKid("abc", jwtConfig) + require.Nil(t, got) + require.Error(t, err) + require.Equal(t, "invalid kid: abc", err.Error()) + } + + // FindPublicKeyByKid - GetSigningKey success + { + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["3"]) + require.NoError(t, err) + + sigKey1, err := jwk.ParseKey(jwt) + require.NoError(t, err) + + dec := make(JwtKeysDecoder) + dec["abc"] = JwkInfo{ + PublicKey: sigKey1, + PrivateKey: sigKey1, + } + + jwtConfig := &JWTConfiguration{ + Keys: dec, + } + + got, err := FindPublicKeyByKid("abc", jwtConfig) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, fmt.Sprintf("%T", got), "*ecdsa.PrivateKey") + } + + // FindPublicKeyByKid - GetSigningKey fails + { + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["4"]) + require.NoError(t, err) + + privKey, err := jwk.ParseKey(jwt) + require.NoError(t, err) + + sentinel := errors.New("sentinel") + key := &mockKey{Key: privKey, err: sentinel, n: 0} + + dec := make(JwtKeysDecoder) + dec["abc"] = JwkInfo{ + PublicKey: key, + PrivateKey: key, + } + + jwtConfig := &JWTConfiguration{ + Keys: dec, + } + + got, err := FindPublicKeyByKid("abc", jwtConfig) + require.Nil(t, got) + require.Error(t, err) + require.Equal(t, sentinel, err) + } + + // GetSigningAlg - nil key + { + alg := GetSigningAlg(nil) + require.Equal(t, alg, jwt.SigningMethodHS256) + } + + // GetSigningAlg - nil key + { + algs := []jwa.KeyAlgorithm{ + jwa.RS256, + jwa.RS512, + jwa.ES256, + jwa.ES512, + jwa.EdDSA, + jwa.HS256, + } + for _, alg := range algs { + key := &mockKey{alg: alg} + got := GetSigningAlg(key) + require.Equal(t, alg.String(), got.Alg()) + } + } +} + +func TestDecode(t *testing.T) { + // array of JWKs containing 4 keys + gotrueJwtKeys := testJwtKey + var decoder JwtKeysDecoder + require.NoError(t, decoder.Decode(gotrueJwtKeys)) + require.Len(t, decoder, 4) + + for kid, key := range decoder { + require.NotEmpty(t, kid) + require.NotNil(t, key.PrivateKey) + require.NotNil(t, key.PublicKey) + require.NotEmpty(t, key.PublicKey.KeyOps(), "missing key_ops claim") + } +} + +func TestJWTConfiguration(t *testing.T) { + // array of JWKs containing 4 keys + gotrueJwtKeys := testJwtKey + var decoder JwtKeysDecoder + require.NoError(t, decoder.Decode(gotrueJwtKeys)) + require.Len(t, decoder, 4) + + cases := []struct { + desc string + config JWTConfiguration + expectedLength int + }{ + { + desc: "GOTRUE_JWT_KEYS is nil", + config: JWTConfiguration{ + Secret: "testsecret", + KeyID: "testkeyid", + }, + expectedLength: 1, + }, + { + desc: "GOTRUE_JWT_KEYS is an empty map", + config: JWTConfiguration{ + Secret: "testsecret", + KeyID: "testkeyid", + Keys: JwtKeysDecoder{}, + }, + expectedLength: 1, + }, + { + desc: "Prefer GOTRUE_JWT_KEYS over GOTRUE_JWT_SECRET", + config: JWTConfiguration{ + Secret: "testsecret", + KeyID: "testkeyid", + Keys: decoder, + }, + expectedLength: 4, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + globalConfig := GlobalConfiguration{ + JWT: c.config, + } + require.NoError(t, globalConfig.ApplyDefaults()) + require.NotEmpty(t, globalConfig.JWT.Keys) + require.Len(t, globalConfig.JWT.Keys, c.expectedLength) + for _, key := range globalConfig.JWT.Keys { + // public keys should contain these require claims + require.NotNil(t, key.PublicKey.Algorithm()) + require.NotNil(t, key.PublicKey.KeyID()) + require.NotNil(t, key.PublicKey.KeyOps()) + require.Equal(t, "sig", key.PublicKey.KeyUsage()) + } + }) + } +} + +func helpToMap(t *testing.T, str string) map[string]map[string]interface{} { + out := make(map[string]map[string]interface{}) + + var dst []map[string]interface{} + err := json.Unmarshal([]byte(str), &dst) + require.NoError(t, err) + for _, v := range dst { + out[v["kid"].(string)] = v + } + return out +} + +func helpToMockKey(t *testing.T, mockErr error) *mockKey { + m := helpToMap(t, testJwtKey) + jwt, err := json.Marshal(m["2"]) + require.NoError(t, err) + + privJwk, err := jwk.ParseKey(jwt) + require.NoError(t, err) + return &mockKey{Key: privJwk, err: mockErr} +} + +type mockKey struct { + jwk.Key + n int + err error + alg jwa.KeyAlgorithm + ops jwk.KeyOperationList +} + +func (o *mockKey) maybeErr() error { + if o.n == 0 { + return o.err + } + o.n++ + return nil +} + +func (o *mockKey) KeyOps() jwk.KeyOperationList { return o.ops } +func (o *mockKey) Algorithm() jwa.KeyAlgorithm { return o.alg } +func (o *mockKey) Validate() error { return o.maybeErr() } +func (o *mockKey) Raw(v any) error { return o.maybeErr() } +func (o *mockKey) Set(string, interface{}) error { return o.maybeErr() } +func (o *mockKey) PublicKey() (jwk.Key, error) { + if err := o.maybeErr(); err != nil { + return nil, err + } + return o.Key.PublicKey() +} + +const testJwtKey = `[{"kty":"oct","k":"9Sj51i2YvfY85NJZFD6rAl9fKDxSKjFgW6W6ZXOJLnU","kid":"1","key_ops":["verify"],"alg":"HS256"},{"kty":"RSA","n":"4slQjr-XoU6I1KXFWOeeJi387RIUxjhyzXX3GUVNb75a0SPKoGShlJEbpvuXqkDLGDweLcIZy-01nqgjSzMY_tUO3L78MxVfIVn7MByJ4_zbrVf5rjKeAk9EEMl6pb8nKJGArph9sOwL68LLioNySt_WNo_hMfuxUuVkRagh5gLjYoQ4odkULQrgwlMcXxXNnvg0aYURUr2SDmncHNuZQ3adebRlI164mUZPPWui2fg72R7c9qhVaAEzbdG-JAuC3zn5iL4zZk-8pOwZkM7Qb_2lrcXwdTl_Qz6fMdAHz_3rggac5oeKkdvO2x7_XiUwGxIBYSghxg5BBxcyqd6WrQ","e":"AQAB","d":"FjJo7uH4aUoktO8kHhbHbY_KSdQpHDjKyc7yTS_0DWYgUfdozzubJfRDF42vI-KsXssF-NoB0wJf0uP0L8ip6G326XPuoMQRTMgcaF8j6swTwsapSOEagr7BzcECx1zpc2-ojhwbLHSvRutWDzPJkbrUccF8vRC6BsiAUG4Hapiumbot7JtJGwU8ZUhxico7_OEJ_MtkRrHByXgrOMnzNLrmViI9rzvtWOhVc8sNDzLogDDi01AP0j6WeBhbOpaZ_1BMLQ9IeeN5Iiy-7Qj-q4-8kBXIPXpYaKMFnDTmhB0GAVUFimF6ojhZNAJvV81VMHPjrEmmps0_qBfIlKAB","p":"9G7wBpiSJHAl-w47AWvW60v_hye50lte4Ep2P3KeRyinzgxtEMivzldoqirwdoyPCJWwU7nNsv7AjdXVoHFy3fJvJeV5mhArxb2zA36OS_Tr3CQXtB3OO-RFwVcG7AGO7XvA54PK28siXY2VvkG2Xn_ZrbVebJnHQprn7ddUIIE","q":"7YSaG2E_M9XpgUJ0izwKdfGew6Hz5utPUdwMWjqr81BjtLkUtQ3tGYWs2tdaRYUTK4mNFyR2MjLYnMK-F37rue4LSKitmEu2N6RD9TwzcqwiEL_vuQTC985iJ0hzUC58LcbhYtTLU3KqZXXUqaeBXEwQAWxK1NRf6rQRhOGk4C0","dp":"fOV-sfAdpI7FaW3RCp3euGYh0B6lXW4goXyKxUq8w2FrtOY2iH_zDP0u1tyP-BNENr-91Fo5V__BxfeAa7XsWqo4zuVdaDJhG24d3Wg6L2ebaOXsUrV0Hrg6SFs-hzMYpBI69FEsQ3idO65P2GJdXBX51T-6WsWMwmTCo44GR4E","dq":"O2DrJe0p38ualLYIbMaV1uaQyleyoggxzEU20VfZpPpz8rpScvEIVVkV3Z_48WhTYo8AtshmxCXyAT6uRzFzvQfFymRhAbHr2_01ABoMwp5F5eoWBCsskscFwsxaB7GXWdpefla0figscTED-WXm8SwS1Eg-bParBAIAXzgKAAE","qi":"Cezqw8ECfMmwnRXJuiG2A93lzhixHxXISvGC-qbWaRmCfetheSviZlM0_KxF6dsvrw_aNfIPa8rv1TbN-5F04v_RU1CD79QuluzXWLkZVhPXorkK5e8sUi_odzAJXOwHKQzal5ndInl4XYctDHQr8jXcFW5Un65FhPwdAC6-aek","kid":"2","key_ops":["verify"],"alg":"RS256"},{"kty":"EC","x":"GwbnH57MUhgL14dJfayyzuI6o2_mB_Pm8xIuauHXtQs","y":"cYqN0VAcv0BC9wrg3vNgHlKhGP8ZEedUC2A8jXpaGwA","crv":"P-256","d":"4STEXq7W4UY0piCGPueMaQqAAZ5jVRjjA_b1Hq7YgmM","kid":"3","key_ops":["sign","verify"],"alg":"ES256"},{"crv":"Ed25519","d":"T179kXSOJHE8CNbqaI2HNdG8r3YbSoKYxNRSzTkpEcY","x":"iDYagELzmD4z6uaW7eAZLuQ9fiUlnLqtrh7AfNbiNiI","kty":"OKP","kid":"4","key_ops":["verify"],"alg":"EdDSA"}]` diff --git a/internal/conf/logging.go b/internal/conf/logging.go new file mode 100644 index 000000000..d079006e3 --- /dev/null +++ b/internal/conf/logging.go @@ -0,0 +1,11 @@ +package conf + +type LoggingConfig struct { + Level string `mapstructure:"log_level" json:"log_level"` + File string `mapstructure:"log_file" json:"log_file"` + DisableColors bool `mapstructure:"disable_colors" split_words:"true" json:"disable_colors"` + QuoteEmptyFields bool `mapstructure:"quote_empty_fields" split_words:"true" json:"quote_empty_fields"` + TSFormat string `mapstructure:"ts_format" json:"ts_format"` + Fields map[string]interface{} `mapstructure:"fields" json:"fields"` + SQL string `mapstructure:"sql" json:"sql"` +} diff --git a/internal/conf/metrics.go b/internal/conf/metrics.go new file mode 100644 index 000000000..ac6f7ec2d --- /dev/null +++ b/internal/conf/metrics.go @@ -0,0 +1,26 @@ +package conf + +type MetricsExporter = string + +const ( + Prometheus MetricsExporter = "prometheus" + OpenTelemetryMetrics MetricsExporter = "opentelemetry" +) + +type MetricsConfig struct { + Enabled bool + + Exporter MetricsExporter `default:"opentelemetry"` + + // ExporterProtocol is the OTEL_EXPORTER_OTLP_PROTOCOL env variable, + // only available when exporter is opentelemetry. See: + // https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md + ExporterProtocol string `default:"http/protobuf" envconfig:"OTEL_EXPORTER_OTLP_PROTOCOL"` + + PrometheusListenHost string `default:"0.0.0.0" envconfig:"OTEL_EXPORTER_PROMETHEUS_HOST"` + PrometheusListenPort string `default:"9100" envconfig:"OTEL_EXPORTER_PROMETHEUS_PORT"` +} + +func (mc MetricsConfig) Validate() error { + return nil +} diff --git a/internal/conf/profiler.go b/internal/conf/profiler.go new file mode 100644 index 000000000..41752bf90 --- /dev/null +++ b/internal/conf/profiler.go @@ -0,0 +1,7 @@ +package conf + +type ProfilerConfig struct { + Enabled bool `default:"false"` + Host string `default:"localhost"` + Port string `default:"9998"` +} diff --git a/internal/conf/rate.go b/internal/conf/rate.go new file mode 100644 index 000000000..059ed65f0 --- /dev/null +++ b/internal/conf/rate.go @@ -0,0 +1,65 @@ +package conf + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +const defaultOverTime = time.Hour + +const ( + BurstRateType = "burst" + IntervalRateType = "interval" +) + +type Rate struct { + Events float64 `json:"events,omitempty"` + OverTime time.Duration `json:"over_time,omitempty"` + typ string +} + +func (r *Rate) GetRateType() string { + if r.typ == "" { + return IntervalRateType + } + return r.typ +} + +// Decode is used by envconfig to parse the env-config string to a Rate value. +func (r *Rate) Decode(value string) error { + if f, err := strconv.ParseFloat(value, 64); err == nil { + r.typ = IntervalRateType + r.Events = f + r.OverTime = defaultOverTime + return nil + } + parts := strings.Split(value, "/") + if len(parts) != 2 { + return fmt.Errorf("rate: value does not match rate syntax %q", value) + } + + // 52 because the uint needs to fit in a float64 + e, err := strconv.ParseUint(parts[0], 10, 52) + if err != nil { + return fmt.Errorf("rate: events part of rate value %q failed to parse as uint64: %w", value, err) + } + + d, err := time.ParseDuration(parts[1]) + if err != nil { + return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err) + } + + r.typ = BurstRateType + r.Events = float64(e) + r.OverTime = d + return nil +} + +func (r *Rate) String() string { + if r.OverTime == 0 { + return fmt.Sprintf("%f", r.Events) + } + return fmt.Sprintf("%d/%s", uint64(r.Events), r.OverTime.String()) +} diff --git a/internal/conf/rate_test.go b/internal/conf/rate_test.go new file mode 100644 index 000000000..a663a8802 --- /dev/null +++ b/internal/conf/rate_test.go @@ -0,0 +1,69 @@ +package conf + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateDecode(t *testing.T) { + cases := []struct { + str string + exp Rate + err string + }{ + {str: "1800", + exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}}, + {str: "1800.0", + exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}}, + {str: "3600/1h", + exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}}, + {str: "3600/1h0m0s", + exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}}, + {str: "100/24h", + exp: Rate{Events: 100, OverTime: time.Hour * 24, typ: BurstRateType}}, + {str: "", exp: Rate{}, + err: `rate: value does not match`}, + {str: "1h", exp: Rate{}, + err: `rate: value does not match`}, + {str: "/", exp: Rate{}, + err: `rate: events part of rate value`}, + {str: "/1h", exp: Rate{}, + err: `rate: events part of rate value`}, + {str: "3600.0/1h", exp: Rate{}, + err: `rate: events part of rate value "3600.0/1h" failed to parse`}, + {str: "100/", exp: Rate{}, + err: `rate: over-time part of rate value`}, + {str: "100/1", exp: Rate{}, + err: `rate: over-time part of rate value`}, + + // zero events + {str: "0/1h", + exp: Rate{Events: 0, OverTime: time.Hour, typ: BurstRateType}}, + {str: "0/24h", + exp: Rate{Events: 0, OverTime: time.Hour * 24, typ: BurstRateType}}, + } + for idx, tc := range cases { + t.Logf("test #%v - duration str %v", idx, tc.str) + + var r Rate + err := r.Decode(tc.str) + require.Equal(t, tc.exp, r) // verify don't mutate r on errr + if tc.err != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.err) + continue + } + require.NoError(t, err) + require.Equal(t, tc.exp, r) + require.Equal(t, tc.exp.typ, r.GetRateType()) + } + + // GetRateType() zero value + require.Equal(t, IntervalRateType, (&Rate{}).GetRateType()) + + // String() + require.Equal(t, "0.000000", (&Rate{}).String()) + require.Equal(t, "100/1h0m0s", (&Rate{Events: 100, OverTime: time.Hour}).String()) +} diff --git a/internal/conf/saml.go b/internal/conf/saml.go new file mode 100644 index 000000000..a38929ebe --- /dev/null +++ b/internal/conf/saml.go @@ -0,0 +1,151 @@ +package conf + +import ( + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "errors" + "fmt" + "math/big" + "net" + "net/url" + "time" +) + +// SAMLConfiguration holds configuration for native SAML support. +type SAMLConfiguration struct { + Enabled bool `json:"enabled"` + PrivateKey string `json:"-" split_words:"true"` + AllowEncryptedAssertions bool `json:"allow_encrypted_assertions" split_words:"true"` + RelayStateValidityPeriod time.Duration `json:"relay_state_validity_period" split_words:"true"` + + RSAPrivateKey *rsa.PrivateKey `json:"-"` + RSAPublicKey *rsa.PublicKey `json:"-"` + Certificate *x509.Certificate `json:"-"` + + ExternalURL string `json:"external_url,omitempty" split_words:"true"` + + RateLimitAssertion float64 `default:"15" split_words:"true"` +} + +func (c *SAMLConfiguration) GoString() string { return c.String() } +func (c *SAMLConfiguration) String() string { + if c == nil { + return "(*SAMLConfiguration)(nil)" + } + return fmt.Sprintf("SAMLConfiguration(Enabled: %v)", c.Enabled) +} + +func (c *SAMLConfiguration) Validate() error { + if c.Enabled { + bytes, err := base64.StdEncoding.DecodeString(c.PrivateKey) + if err != nil { + return errors.New("SAML private key not in standard Base64 format") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(bytes) + if err != nil { + return errors.New("SAML private key not in PKCS#1 format") + } + + if privateKey.E != 0x10001 { + return errors.New("SAML private key should use the 65537 (0x10001) RSA public exponent") + } + + if privateKey.N.BitLen() < 2048 { + return errors.New("SAML private key must be at least RSA 2048") + } + + if c.RelayStateValidityPeriod < 0 { + return errors.New("SAML RelayState validity period should be a positive duration") + } + + if c.ExternalURL != "" { + _, err := url.ParseRequestURI(c.ExternalURL) + if err != nil { + return err + } + } + } + + return nil +} + +// PopulateFields fills the configuration details based off the provided +// parameters. +func (c *SAMLConfiguration) PopulateFields(externalURL string) error { + // errors are intentionally ignored since they should have been handled + // within #Validate() + bytes, err := base64.StdEncoding.DecodeString(c.PrivateKey) + if err != nil { + return fmt.Errorf("saml: PopulateFields: invalid base64: %w", err) + } + + privateKey, err := x509.ParsePKCS1PrivateKey(bytes) + if err != nil { + return fmt.Errorf("saml: PopulateFields: invalid private key: %w", err) + } + + c.RSAPrivateKey = privateKey + c.RSAPublicKey = privateKey.Public().(*rsa.PublicKey) + + parsedURL, err := url.ParseRequestURI(externalURL) + if err != nil { + return fmt.Errorf("saml: unable to parse external URL for SAML, check API_EXTERNAL_URL: %w", err) + } + + host := "" + host, _, err = net.SplitHostPort(parsedURL.Host) + if err != nil { + host = parsedURL.Host + } + + // SAML does not care much about the contents of the certificate, it + // only uses it as a vessel for the public key; therefore we set these + // fixed values. + // Please avoid modifying or adding new values to this template as they + // will change the exposed SAML certificate, requiring users of + // GoTrue to re-establish a connection between their Identity Provider + // and their running GoTrue instances. + certTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(0), + IsCA: false, + DNSNames: []string{ + "_samlsp." + host, + }, + KeyUsage: x509.KeyUsageDigitalSignature, + NotBefore: time.UnixMilli(0).UTC(), + NotAfter: time.UnixMilli(0).UTC().AddDate(200, 0, 0), + Subject: pkix.Name{ + CommonName: "SAML 2.0 Certificate for " + host, + }, + } + + if c.AllowEncryptedAssertions { + certTemplate.KeyUsage = certTemplate.KeyUsage | x509.KeyUsageDataEncipherment + } + return c.createCertificate(certTemplate) +} + +func (c *SAMLConfiguration) createCertificate(certTemplate *x509.Certificate) error { + certDer, err := x509.CreateCertificate(nil, certTemplate, certTemplate, c.RSAPublicKey, c.RSAPrivateKey) + if err != nil { + return err + } + return c.parseCertificateDer(certDer) +} + +func (c *SAMLConfiguration) parseCertificateDer(certDer []byte) error { + cert, err := x509.ParseCertificate(certDer) + if err != nil { + return err + } + + c.Certificate = cert + + if c.RelayStateValidityPeriod == 0 { + c.RelayStateValidityPeriod = 2 * time.Minute + } + return nil +} diff --git a/internal/conf/saml_test.go b/internal/conf/saml_test.go new file mode 100644 index 000000000..aa9c1262c --- /dev/null +++ b/internal/conf/saml_test.go @@ -0,0 +1,215 @@ +package conf + +import ( + "crypto/x509" + "encoding/base64" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSAMLConfiguration(t *testing.T) { + t.Run("String", func(t *testing.T) { + // string disabled + { + cfg := &SAMLConfiguration{Enabled: false} + const expStr = "SAMLConfiguration(Enabled: false)" + require.Equal(t, expStr, fmt.Sprintf("%v", cfg)) + require.Equal(t, expStr, fmt.Sprintf("%#v", cfg)) + } + + // string enabled + { + cfg := &SAMLConfiguration{Enabled: true} + const expStr = "SAMLConfiguration(Enabled: true)" + require.Equal(t, expStr, fmt.Sprintf("%v", cfg)) + require.Equal(t, expStr, fmt.Sprintf("%#v", cfg)) + } + + // string (nil) + { + var cfg *SAMLConfiguration + const expStr = "(*SAMLConfiguration)(nil)" + require.Equal(t, expStr, fmt.Sprintf("%v", cfg)) + require.Equal(t, expStr, fmt.Sprintf("%#v", cfg)) + } + }) + + t.Run("PopulateFields", func(t *testing.T) { + c := &SAMLConfiguration{ + Enabled: true, + PrivateKey: validPrivateKey, + } + err := c.PopulateFields("https://projectref.supabase.co") + require.NoError(t, err) + + isSet := (c.Certificate.KeyUsage & x509.KeyUsageDataEncipherment) != 0 + require.False(t, isSet) + require.NotNil(t, c.RSAPrivateKey) + require.NotNil(t, c.RSAPublicKey) + require.NotNil(t, c.Certificate) + }) + + t.Run("PopulateFieldsEncryptedAssertions", func(t *testing.T) { + c := &SAMLConfiguration{ + Enabled: true, + PrivateKey: validPrivateKey, + AllowEncryptedAssertions: true, + } + err := c.PopulateFields("https://projectref.supabase.co") + require.NoError(t, err) + + isSet := (c.Certificate.KeyUsage & x509.KeyUsageDataEncipherment) != 0 + require.True(t, isSet) + require.NotNil(t, c.RSAPrivateKey) + require.NotNil(t, c.RSAPublicKey) + require.NotNil(t, c.Certificate) + }) + + t.Run("PopulateFieldsInvalidExternalURL", func(t *testing.T) { + c := &SAMLConfiguration{ + Enabled: true, + PrivateKey: "invalidprivatekey", + } + err := c.PopulateFields("\n") + require.Error(t, err) + }) + + t.Run("PopulateFieldsInvalidx509", func(t *testing.T) { + c := &SAMLConfiguration{ + Enabled: true, + PrivateKey: validPrivateKey, + } + err := c.PopulateFields("http://invalid\nhost/foo") + require.Error(t, err) + }) + + t.Run("PopulateFieldsInvalidPKCS1", func(t *testing.T) { + c := &SAMLConfiguration{ + Enabled: true, + PrivateKey: base64.StdEncoding.EncodeToString([]byte("INVALID")), + } + err := c.PopulateFields("https://projectref.supabase.co") + require.Error(t, err) + }) + + t.Run("PopulateFieldInvalidCreateCertificate", func(t *testing.T) { + c := &SAMLConfiguration{ + Enabled: true, + PrivateKey: base64.StdEncoding.EncodeToString([]byte("INVALID")), + } + + tmpl := &x509.Certificate{} + err := c.createCertificate(tmpl) + require.Error(t, err) + }) + + t.Run("PopulateFieldInvalidCertificateDer", func(t *testing.T) { + c := &SAMLConfiguration{ + Enabled: true, + PrivateKey: validPrivateKey, + } + err := c.PopulateFields("https://projectref.supabase.co") + require.NoError(t, err) + + err = c.parseCertificateDer([]byte{0x0, 0x0}) + require.Error(t, err) + }) +} + +func TestSAMLConfigurationValidate(t *testing.T) { + invalidExamples := []*SAMLConfiguration{ + { + Enabled: true, + PrivateKey: "", + }, + { + Enabled: true, + PrivateKey: "InvalidBase64!", + }, + { + Enabled: true, + PrivateKey: validPrivateKey, + RelayStateValidityPeriod: -1, + }, + { + Enabled: true, + PrivateKey: validPrivateKey, + ExternalURL: "\n", + }, + { + Enabled: true, + PrivateKey: base64.StdEncoding.EncodeToString([]byte("not PKCS#1")), + }, + { + Enabled: true, + PrivateKey: base64.StdEncoding.EncodeToString([]byte("not PKCS#1")), + }, + { + // RSA 1024 key + Enabled: true, + PrivateKey: "MIICXQIBAAKBgQDFa3SgzWZpcoONv3Iq3FxNieks2u2TmykxxxeggI9aNpHpuCzwGQO8wqXGVvFNlkE3GSPcz7rklzfyj577Z47lfWdBP1OAefralA3tS2mafqpZ32JwDynX4as+xauLVdP4iOR96b3L2eOb6rDpr4wBJuNqO533xsjcbNPINEDkSwIDAQABAoGASggBtEtSHDjVHFKufWQlOO5+glOWw8Nrrz75nTaYizvre7mVIHRA8ogLolT4KCAwVHkY+bTsYMxULqGs/JnY+40suHECYQ2u76PTQlvJnhJANGtCxuV4lSK6B8QBJhjGExsnAOwMMKz0p5kVftx2GA+/Rz2De7DR9keNECjcAAECQQDtr5cdkEdnIffvi782843EvX/g8615AKZeUYVUl0gVXujjpIVZXDtytPHINvIW1Z2mOm2rlJukwiKYYJ8IjsxlAkEA1KGbJ9EI6AOUcnpy7FYdGkbINTDngCqVOoaddlHS+1SaofpYXZPueXXIqIG3viksxmq/Q0IY6+JRkGo/RpGq7wJARD+BAqok9oYYbR4RX7P7ZxyKlYsiqnX3T2nVAP8XYZuI/6SD7a7AGyW9ryGnzcq0o8BvMS9QqbRcvqgvwgNOyQJBAL2ZVMaOSIjKGGZz9WHz74Nstj1n3CWW0vYa7vGASMc/S5s/pefbbvvzIPfQo0z3XiuXJ/ELUTmU1vIVK1L7tRUCQQCsuE7xckZ8H/523jdWWy9zZVP1z4c5dVLDR5RY+YQNForgb6kkSv4Gzn/FRUOxqn2MEWJLla31D4EuS+XKuwZR", + }, + { + // RSA 2048 with 0x11 as public exponent + Enabled: true, + PrivateKey: "MIIEowIBAAKCAQEAyMvTanPoiorCpIQCl70qXF34FIPOkKaInr1vw+3/0nik5CDUo761E02uTrK4/8JXr5NLGmy/fQmagNsBOdKewciRB3xxs+sPNncptG4rpCBjxSJdVl+mYZaw2kdvFY7TvNTlr7qG1Q0kV/3lBgpMlyM9OqBrjuG0UUzB5hlg08KLNflkQAkoJGWNVWULi2VceP3I3QsH9uNUQkgaM9Z6rl0BaRAkobHTTvquAqqj1AlNmSh24rrIbV4hYcNnesIpG4+LDd8XfpOwTp+jUl8akF6xcRBJjiPDJGN9ety29DcCxjo2i0b+TWYU+Pex08uOeOdulsgecbIVxLUEgRHcFQIBEQKCAQBefgkjCV5fUFuYtpfO75t2ws8Ytn9TITE7pHDUrDwmz1ynlvqnaM2uuyTZvYQ8HzhSn6rfQjv+mxuH7pcqRP9qQEQ/whdjuekKkm36jjKnlsWJ8g3OSyEe3YBmuDRGYVSVGOSO7l2Rb5ih4OQ/E+fOpyvfWoz38b5EYFs/GwBjpgJG+9cdCLYKOax8WDifWkjHdrogAlE8do/QF6RZoSvhAbRkpuxYActmKU8rIORrq8dLidSjBG2aoRH+RCN4ONZ3R4iHbYF2zWfqDFdSIX64kChaOZVhtTyTnF7/1v4VF3UwByEs8hTSckFH2jW6T7RZoatpgsv5zx/roRPDBWNRAoGBAPGphQwX9GF56XVmqD9rQMD9a0FuGCNGgiFkt2OUTFKmr4rTNVE8uJqEXjQwbybTnF2iJ1ApL1zNHg0cpnpMt7ZpcWG4Bu2UsXlwBL/ZwY4Spk6tHNdTsg/wuoWRSIGNanNS6CI5EUA4cxGNUt0G+dF4LaMHZuIAU7avs+kwDMzHAoGBANS1nS8KYkPUwYlmgVPNhMDTtjvq7fgP5UFDXnlhE6rJidc/+B0p9WiRhLGWlZebn+h2fELfIgK3qc4IzCHOkar0pic2D3bNbboNQKnqFl81hg0EORTK0JJ5/K4J61l5+rZtQu3Ss1HVwDiy9SKg6F3CQj9PK0r+hjtAStFSmZxDAoGBAMcEEzciyUE3OLsJP0NJRGKylJA8jFlJH99D4lIBqEQQzMyt76xQH45O5Cr6teO9U5hna6ttNhAwcxnbW+w/LeGEAwUuI9K2sEXjx60NrnUATLlDRO2QOElc1ddolhBWV6pERrLFlbxquR2DcWq6c2E1yzr3CW7TF8OfwVagCoqFAoGBAK8sJxeuMs5y+bxyiJ9d9NsItDFYD0TBy9tkqCe5W32W6fyPCJB86Df/XjflbCKAKVYHOSgDDPMt1yIlXNCL/326arbhOeld4eSDYm3P1jBKMijWTSAujaXN3yXqDRyCkjvhgmmAV3CR6Zga5/5mZQHrRZ2MfgGGUG0HxSTanJ7NAoGBAOhZBGtFsBdtEawvCh4Z8NaMC2nU+Ru9hEsZSy6rZQrPvja9aBUk5QUdh04TYtu8PzQ1EghZy71rtwDAvxXWJ1mWcZn0kD06tZKudmZpMVXCp3SFah6DDUCFSmQ2U60yh6XOzpS2+Z97Ngi02UFph8sSQA6Dl/lmaf4bfQHCYc5Z", + }, + } + + for i, example := range invalidExamples { + err := example.Validate() + require.Error(t, err, "Invalid example %d was regarded as valid", i) + } + + validExamples := []*SAMLConfiguration{ + { + Enabled: false, + }, + { + // RSA 2048 + Enabled: true, + PrivateKey: validPrivateKey, + }, + { + // RSA 3072 + Enabled: true, + PrivateKey: "MIIG4wIBAAKCAYEApYkvDaXJEDsELSVosc0sKFnoPeJai8sOu8di5ffGVJRr7mJi+VQjM0d2KeOIllVk2IV58M33Jz2Rx61NYPLu0N9fZqPwbgYn+FNz1L1xgslUL6gyaQnCEKtH5mRqPEBOPvAygq/fZ46eBMs3GSS6NWp/XF/iPaFc1mBDAZFvXev4XV7O6iuqz5mx3rQbkIhMjQxP+IOYWMS4TqueLJWgFUbij0FepJfOE+AlmfBa7xIOyE+g5t3vRB8XwzxRPsljlfgZXstxO1r1NS3DPiUj3kGYy7em5Yb+icIA6xzy0MiwU5RcBSwtVc+M/Yk2tMY6a9z1UX2M5Zr/ih3w0CbW6KDYplqgwwDZv2f+ynIqldn7SjVo3V6fWFu+KtRkofWWkTGjaU2DTpxrxUJEnEo6zXfBSejAjGGAJyKjX74uATlOu/LQEjd5umQpWYvtvP1UkbjHYgITtoTytb3uU7Q7W/YdtNUcaE377QHZF+E+XTCCCw00bCvpDciW+w0JSkRfAgMBAAECggGAR0jCKIBiC0k+zSo04YxXHbFJ34xgLZ7t41NDdYCzuayIpglcUb43wlddvUAsi4COgudH0bkAW7eZ1YD9t2gmC3CFpq+mU9r2z2swkEZcYVPNmxA1VSJMnd0Eg2Ruky+mAlhxh/GwpOm3hpz0RzGXtnT8D42C4cNhNTgS4tP8P1fkhmDTfef8EJZBEIRC8oSfYoYQ0hXpPyDHtakV3mE4pLD303T1CrAMoGaACsCEiDsgfoY75e9gn9c75mlNG1qhhJYxD3Sv1o9lQd3Q1A71sga/E+yIlUcPP4fDaA8DdeH+FHwL9xgQPd18gsrbPdbsg8JMLmjblaz8BB1MvJMwj+b3Ey2idD8CVIq5Ql97TebyMxZp3ZYjLq/R2ay+MpE9Vjgih096Hg+kCPMPi3Q9AmVJX8kN8+2zm2EeDoI/YnJFzmBcmaOuSBEGYdrRk5RCYfZMa1jvpoNUGbWzoX4gRfC7Gr+alaCWa9ot2c+ChWZQlpbKaMYMLU/VEd7gsf/BAoHBANJsSdIxrTUWpdm27SJlq5ylKVokS5ftrxk2gM8M16jLYE7SwPrhdxphWGH8/TMz6/Jyd+CuSfwAzyy325szlFlZVpxv8qu1vWROBaaaq1Tg8cqYC2s+hUTJLevcmiBHFu+7tiYNmMqkNIfj9/FN1zvfPVwqurtB5WXGjI4qhf5SyJgtj1GiM/s9Ae86LiRZhovcEEwf0LddGpMrUEDrWOV9D95sOMA00rsJXOfOg78Ms7Nq/h9w6cnD5x4jUJTMzwKBwQDJY/TMNVa1V8ci+pOMB6iTyi3azVC6ZiCXinCQS0oLebY1GmyWLv9A+/+9Wg/h4p4OdlZSA2/9f6+6njAcxI1wfzHVC3vgF7EDs9YUeAmXWBA171uPHbfimTd21utLkcyJ/WdO4OmKP7ZIK8UWyXE98N5NQV9NRX0sm6CJemwChcoJ8/7lsuYa4nJVUXtAkAMoj7e0nOoWn1IzyolmIXSTrBPiLWh68172tr3ciR6uGN3Yba6szkFTeaBDfNQvk3ECgcEAy07XkKBwwv+L5SxKOFbVlfc6Wh8Bbty2tnyjvemhoTRHbEFTNdOMeU+ezqZamhNLoKgazVp4n2TEx2cpZu5SInYgKew8Is3pHLYJ3axJaCwjUmTPe6Ifr5NVrDMsM42cSqsqVeADRZ+cJcQMtvhHwlByf8/FNdJ4a3qIKYBKkKy5pdc3R1+aK+AJM3QaSwK47f8FPBftWI07dQB/fQonjSvlnjkgKA2hohdszYgKYRhLtEnnGMfHCywd7U+ftvWfAoHAcxfq+SGqkiy+I+FsnWRrFTtAhYE9F6nyCmkV94Dvqis+1I5rbFEjk7Hw7/geh4uJpN5AatKIGCn29gIdoPM7mgU3J3hOrT0c7u7B9CS95n5vlUNb4iirxJanugUNp7yFVn85oTyse1P6CrjpBCLP0wRrJ1+q5XBHH005rBgIzlBDrPiCvidFlivAB75vX/BtvaqU5GWg6pjW0752U6XfB94Z5vLoeQvJQ9ogG39Jx1lyv5O/dgbSErC5xJf8c8whAoHAYdxLfZcDN2IEUgg/czE9jDSJ2EwOCE9HpCntthHAvceK3+PFfpCKwOLynqF8urhdeM510QJK2ETLvzpgMBgSh/klxeBYv8BCL8BuPwyPciAFmPE1Stx7C1+JBF2fayYkCSK9w85INLAJYKTDk9gE8O6l0bXA8tuq3F0tRTwMBcyEpMOehKFamoPcU6cnNa2HC+MyTOfXSBeNZ2VciFYf5rh3YrwoUYbQJtDXxFvoX0Ba+zyneNG0j3epXZuR2lyK", + }, + { + // RSA 4096 + Enabled: true, + PrivateKey: "MIIJKgIBAAKCAgEA2cNnNX4Be3jOKTr7lxIWxWfFKtwFqbWs9CZS7gDNXUtBlGuV1+FswPvSRKWEmwsBQikBfScowk4hL/JFgN8V25PijOk7eTPmw3tHuUhoil7GkJCMKhtrYwGbvINk1pK5mfI+V8GR3l52S779fg8nwktOtr99sLgfxUdxwxFY5hE5lo5P19QPClAA89SjQ3c/FlXy8R56/qf4u+Fuvd7Ecq7nQGeovsiSpBxY2gn4KL2LdkkyZmEQVgXzXjDGOOhF7M6eKim5MCsUqgHjCCkK7Gw9HNbd4oHNE5ucWRYjG1IpEYbYmep/9+wXgwQorYFKUT0NXrUv5H3VLQpsDyWDRZJ+wXGbwV2bRh2Z5bbAJVTxF8NaO8XujVZLIe+UJ8kUWj+n3hxwil9UU9yExR6M9TZBfHTKOVWcn1CquT85ppI0dtvlu3ToBwjjcd1wWLK8rLhmEwafC142bSL2kXLc6p7YrhTBN7PBPodQ2lLMg8xbw4cNspsMAPAPfrisqEYUGAs/EUScgcsSfmyzKNcdZlUx6UkMhz2F8sKPi4I4oIugxQiCa7LuSjmfrM6msIkrV+sj06zUYmAZzN+cf7rRlGFLNt1cKqqukjhbo9RL54XZQssT5GkHuVT6neyQBJX9EwtmZtXBTI78WTUabQhBcEBbxWbn5VodxDPXmfAiumsCAwEAAQKCAgEAnU1ux5BPJ877NYNa7CTv+AdewPgQyyfmWLM6YpyHvKW5KKqSolA/jCQcHuRlps3LSexvG+Xmpn1jscvTcyUzF9t64ok0Ifhg8MKj6+6nPZT64MDZzyzhZLJrukA73lg85Dy91gyI/1XDJDJB0QbHlK1rnc0z0S0gHhTe06c7TW4R6HTCrkiL2Moz9e6bRQfltY++n3iCJmRV4/oTUeqSg7leaQK4PaCLdSrY8CAVd/B7xqVXV+czssA3rcmT1tXKdSZH0HM1R9tG4Qvd4S4sqt4BQ0zfGVjkOA7HYP8BuyGdcwCyhHSFniSYU1b0v2jOs2Jjvw8pGmffTtrhdguGB60rMocKyfXvRxjJmIXZae6W8ZCwz76rKr8igXZUXvK3LqhGfm5fDpvWQlX8ugnwWOmowJqToS/fVKwhjFjsPONRbRZh7MTebRjx9ErpQycTm0SiUrUA/WE8Na1JeelTjxThCuy1VjIOtYVk4eYGP6REQV+nYGGuD7ruR+dpD4UR3/2DsPLik8X+YUFMjGCr+LjzybDj8Ux+a/u/eKD3rIe45PooJzGR/s+RCcwtAIue29+C+2uj3lAypEIqRGd2k0RgEw8Cj43Omc3Pyf+M3IbKfpE82OGSPp/rgHIfJSwGuOWH09yxCjyqY9H/wtxea6qOpeuk/g4ipaTp/QvZikkCggEBAPeowAf5hz16Oreb4D1c9MoHdofO/Kv8D8Kig9YhwDGck8uG4TbdBMSre5mOGSykZBA566/YHf49Pdi5Eq3L5oEdNKJrxn35WlMHY2pdOCrhWfqe+Z3Fg6qlhQTFU0blFAwy6NUixHP7xsLyAdpjkSxdsQzOaHUMII8w0gD+/AqSq3c/sC9AF+CeiZQV0P53eseNVfxfv8f1aDH7JcywG4P6Xe9pdHoNW93u2j2HQcrLidOtsT5s8iXj2YO3d4YZg/I20dViC7+DrG1ep+rfiuYY5VS1jKVqTknzKHlP7OHOaYJhDPAffnNFBWj4Th11NKxigpx3ogXO9jVyCGXWwD0CggEBAOEY5hvGEufmWx821Oln3OoaUv4MBSa0RMuuBZyx7mx18gfidjB9exC6sFduTwdFjnM8PUDBnOo2ZM7ThcCbTJ4Qi7LB5gDAXUbJqJk7o+lKrfXcpYdksoXWHmWAT7RE1v9nbXle1KHKIaaga/I8hVtSfeTizb8y+dDP3T3H8tVByvneAE0LnDVmr1VhFppKnzWl5vTY2Y+6XGIWmrCuWS1+zf+dx32zJ2ZOfT1Wwk20igC79RzH0sDHSv7DNyUn9u/9LtjIIrDtWch9+5Xkq0uZQAqM0Jw/QUYqarJSNNVhREmwWk+B6sJaQUN26YyTHiOpfFu1RUwHyyg58L8yJ8cCggEBALqSqnhXh4bM+kcwavJPgSpiDO2rBbcbIVRj0iYTLxMw/jap2ijWwKzY8zhvUI/NGIUQ3XmPuqi5wknuwx+jKHfEZM6nmtV0cJN0UXTj3ViQhJTGBw7Qqax5HYjGj0Itebjm8Xj/xDgMSWS7pKG9uLRPsP4Q0ai8BhtZkBun/ICKlho0JKq0Akj5pnOlK9lIcXq8AzcpevVM774XkhZt5Yy7pOCj9VetkLPVKRyJNQtt4ttRUuHQeWwKBuev459mwXxLyDCUuH0C2Xdbg+zxk1ZdEweJ7fb/6xLS2H7rs205b0sFihWr5Ds6mCTISzDuB0yGuhbeGXV+wQTqb2EpM5ECggEBAMBFsGiQ7J1BWxxyjbNBkKY3DiUKx2ukGA+S+iA6rFng9XherG4HARPtI6vLAZ5If8FW90tVFl/JTpqMe3dmMC/kGi/7CCgkKIjKwEUDeKNRsv6MFqhsD0Ha/+Pbkjl9g9ht1EkUA7SfH9dguFQV9iNndzoHsY9cT59ZrrWTEY2vwV1lkAQ/opLKv4HCiLgKfawppfoHMO9gVIFEpaW9h1chNXzenQR1/3WYHcpDTX1qdWbjJiALX65jjV/ICFaoqHmeXmG1skxGsaZcVoZW6SqOIPHiDl8oeO0iVjkzlwWdK+N1y+6WHp0c0xp5fE0jbV8w6pS7ZhHnplUaCNaIVQkCggEAUcQ0VhA1FlWT/mcbTp3v3ojCkZH8fkbNQoL5nsxb+72VDGB0YYZBXyBCApY2E/3jfAH8nG0ALAVS1s983USx17Z2+Z+Yg13Gt1dglV0JC2vMS6j83G0UxsKdcFyafbgJl+hrvBAuOoqPLd5r5F6VnDZeDDsJ3Y6ZTmcsbf8EZkUSxT80oKBLwXi06dfnEz7nYUxvqk54QG3xN1VJAQoKaJ9sH9pbAPdA0GxRx8DIWBd3UhMFJbdIplfGlkk9kf+E1k6Z2SaRB8QQHpvdgsdQ6YXPV+0ejhiGytX9DMSmjZe3dC4C7ZdaCL+kSxdFRgIo2KAcJVdpsqbw/hclfNY7cQ==", + }, + } + + for i, example := range validExamples { + err := example.Validate() + require.NoError(t, err, "Valid example %d was regarded as invalid", i) + } +} + +func TestSAMLConfigurationDeterministicCertificate(t *testing.T) { + a := &SAMLConfiguration{ + Enabled: true, + PrivateKey: "MIIEowIBAAKCAQEAt7dS8iM5MsQ+1mVkNpoaUnL8BCdxSrSx8jsSnvqN/GIJ4ipqbdrTgLpFVklVTqfaa5CykGVEV577l6AWkpkm2p7SvSkCQglmyAMMjY9glmztytAnfBpm+cQ6ZVTHC4XKlUG1aJigEuXPcZUU3FiBHWEuV2huYy2bLOtIY1v9N0i2v61QCdG+SM/Yb5t86KzApRl7VyHqquge6vvRuchfF0msv/2LW32hwxg3Gt4zkAF0SJqCCcfAPZ9pQwmbdUhoX16dRFU98nyIvuR8LH/wONZe/YyywFFHDEwkFa4XEzjCEm+AD+xvK7eEu55w21xB8JKMLEBy8uRuI3bIEG4pawIDAQABAoIBADw4IT4xgYw8e4R3U7P6K2qfOjB6ZU5hkHqgFmh6JJR35ll2IdDEi9OEOzofa5EOwC/GDGH8b7xw5nM7DGsdPHko2lca3BydTE1/glvchYKJTiDOvkKVvO9d/O4+Lch/IHpwQXB5pu7K2YaXoXDgqeHhevk3yAdGabj9norDGmtGIeU/x1hialKbw6L080CdbxpjeAsM/w+G/VtwvyOKYFBYxBflRW+sS8UeclVqKRAvaXKd1JGleWzH3hFZyFI54x5LyyjPI1JyVXRjNbf8xcS6eRaN849grL1+wBxEs/lQFn4JLhAcNi912iJ3lhxvkNleXZw7B7JAM8x4wUbK7zECgYEA6SYmu3YH8XuLUfT8MMCp+ETjPkNMOJGQmTXOkW6zuXP3J8iCPIxtuz09cGIro+yJU23yPUzOVCDZMmnMWBmkoTKAFoFL9TX0Eyqn/t1MD77i3NdkMp16yI5fwOO6yX1bZgLiG00W2E5/IGgNfTtEafU/mre95JBnTgxS3sAvz8UCgYEAybjfBVt+1X0vSVAGKYHI9wtzoSx3dIGE8G5LIchPTdNDZ0ke0QCRffhyCGKy6bPos0P2z5nLgWSePBPZQowpwZiQVXdWE05ID641E2zGULdYL1yVHDt6tVTpSzTAy89BiS1G8HvgpQyaBTmvmF11Fyd/YbrDxEIHN+qQdDkM928CgYEA4lJ4ksz21QF6sqpADQtZc3lbplspqFgVp8RFq4Nsz3+00lefpSskcff2phuGBXBdtjEqTzs5pwzkCj4NcRAjcZ9WG4KTu4sOTXTA83TamwZPrtUfnMqmH/2lEdd+wI0BpjryRlJE9ODuIwUe4wwfU0QQ5B2tJizPO0JXR4gEYYkCgYBzqidm4QGm1DLq7JG79wkObmiMv/x2t1VMr1ExO7QNQdfiP1EGMjc6bdyk5kMEMf5527yHaP4BYXpBpHfs6oV+1kXcW6LlSvuS0iboznQgECDmd0WgfJJtqxRh5QuvUVWYnHeSqNU0jjc6S8tdqCjdb+5gUUCzJdERxNOzcIr4zQKBgAqcBQwlWy0PdlZ06JhJUYlwX1pOU8mWPz9LIF0wrSm9LEtAl37zZJaD3uscvk/fCixAGHOktkDGVO7aUYIAlX9iD49huGkeRTn9tz7Wanw6am04Xj0y7H1oPPV7k5nJ4s9AOWq/gkZEhrRIis2anAczsx1YHSjq/M05+AbuRzvs", + } + + b := &SAMLConfiguration{ + Enabled: a.Enabled, + PrivateKey: a.PrivateKey, + } + + err := a.PopulateFields("https://projectref.supabase.co") + require.NoError(t, err) + + err = b.PopulateFields("https://projectref.supabase.co") + require.NoError(t, err) + + require.Equal(t, a.Certificate.Raw, b.Certificate.Raw, "Certificate generation should be deterministic") +} + +const ( + validPrivateKey = "MIIEowIBAAKCAQEAsBuxTUWFrfy0qYXaqNSeVWcJOd6TQ4+4b/3N4p/58r1d/kMU+K+BGR+tF0GKHGYngTF6puvNDff2wgW3dp3LUSMjxOhC3sK0uL90vd+IR6v1EDDGLyQNo6EjP/x5Gp/PcL2s6hZb8iLBEq4FksPnEhWqf9Nsmgf1YPJV4AvaaWe3oBFo9zJobSs3etTVitc3qEH2DpgYFtrCKhMWv5qoZtZTyZRE3LU3rvInDgYw6HDGF1G4y4Fvah6VpRmTdyMR81r1tCLmGvk61QJp7i4HteazQ6Raqh2EZ1sH/UfEp8mrwYRaRdgLDQ/Q6/YlO8NTQwzp6YwwAybhMBnOrABLCQIDAQABAoIBADqobq0DPByQsIhKmmNjtn1RvYP1++0kANXknuAeUv2kT5tyMpkGtCRvJZM6dEszR3NDzMuufPVrI1jK2Kn8sw0KfE6I4kUaa2Gh+7uGqfjdcNn8tPZctuJKuNgGOzxAALNXqjGqUuPa6Z5UMm0JLX0blFfRTzoa7oNlFG9040H6CRjJQQGfYyPS8xeo+RUR009sK/222E5jz6ThIiCrOU/ZGm5Ws9y3AAIASqJd9QPy7qxKoFZ1qKZ/cDaf1txCKq9VBXH6ypZoU1dQibhyLCIJ3tYapBtV4p8V12oHhITXb6Vbo1P9bQSVz+2rQ0nJkjdXX/N4aHE01ecbu8MpMxUCgYEA5P4ZCAdpkTaOSJi7GyL4AcZ5MN26eifFnRO/tbmw07f6vi//vdqzC9T7kxmZ8e1OvhX5OMGNb3nsXm78WgS2EVLTkaTInG6XhlOeYj9BHAQZDBr7rcAxrVQxVgaGDiZpYun++kXw+39iq3gxuYuC9mM0AQze3SjTRIM9WWXJSqMCgYEAxODfXcWMk2P/WfjE3u+8fhjc3cvqyWSyThEZC9YzpN59dL73SE7BRkMDyZO19fFvVO9mKsRfsTio0ceC5XQOO6hUxAm4gAEvMpeapQgXTxIxF5FAQ0vGmBMxT+xg7lX8HTTJX/UCttKo3BdIJQeTf8bKVzJCoLFh8Rcv5qI6umMCgYAEuj44DTcfuVmcpBKQz9sA5mEQIjO8W9/Xi1XU4Z2F8XFqxcDo4X/6yY3cDpZACV8ry3ZWtqA94e2AUZhCH4DGwMf/ZMCDgkD8k/NcIeQtOORvfIsfni0oX+mY1g+kcSSR1zTdY95CwvF9isC0DO5KOegT8XkUZchezLrSgqhyMwKBgQCvS0mWRH6V/UMu6MDhfrNl0t1U3mt+RZo8yBx03ZO+CBvMBvxF9VlBJgoJQOuSwBVQmpdtHMvXD4vAvNNfWaYSmB5hLgaIcoWDlliq+DlIvfnX8gw13xJD9VLCxsTHcOe5WXazaYOxJIAU9uXVkplR+73NRYLtcQKzluGfiHKh4QKBgFpPtOqcAbkMsV+1qPYvvvX7E4+l52Odb4tbxGBYV8tzCqMRETqMPVxFWwsj+EQ8lyAu15rCRH7DKHVK5zL6JvIZEjt0tptKqSL2o3ovS6y3DmD6t+YpvjKME7a+vunOoJWe9pWl3wZmodfyZMpAdDLvDGhPR7Jlhun41tbMMaQF" +) diff --git a/internal/conf/tracing.go b/internal/conf/tracing.go new file mode 100644 index 000000000..9a1d9bee8 --- /dev/null +++ b/internal/conf/tracing.go @@ -0,0 +1,33 @@ +package conf + +type TracingExporter = string + +const ( + OpenTelemetryTracing TracingExporter = "opentelemetry" +) + +type TracingConfig struct { + Enabled bool + Exporter TracingExporter `default:"opentelemetry"` + + // ExporterProtocol is the OTEL_EXPORTER_OTLP_PROTOCOL env variable, + // only available when exporter is opentelemetry. See: + // https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md + ExporterProtocol string `default:"http/protobuf" envconfig:"OTEL_EXPORTER_OTLP_PROTOCOL"` + + // Host is the host of the OpenTracing collector. + Host string + + // Port is the port of the OpenTracing collector. + Port string + + // ServiceName is the service name to use with OpenTracing. + ServiceName string `default:"gotrue" split_words:"true"` + + // Tags are the tags to associate with OpenTracing. + Tags map[string]string +} + +func (tc *TracingConfig) Validate() error { + return nil +} diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go new file mode 100644 index 000000000..bb8688bb6 --- /dev/null +++ b/internal/crypto/crypto.go @@ -0,0 +1,169 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base32" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math" + "math/big" + "strconv" + "strings" + + "golang.org/x/crypto/hkdf" +) + +// GenerateOtp generates a random n digit otp +func GenerateOtp(digits int) string { + upper := math.Pow10(digits) + val := must(rand.Int(rand.Reader, big.NewInt(int64(upper)))) + + // adds a variable zero-padding to the left to ensure otp is uniformly random + expr := "%0" + strconv.Itoa(digits) + "v" + otp := fmt.Sprintf(expr, val.String()) + + return otp +} +func GenerateTokenHash(emailOrPhone, otp string) string { + return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp))) +} + +// Generated a random secure integer from [0, max[ +func secureRandomInt(max int) int { + randomInt := must(rand.Int(rand.Reader, big.NewInt(int64(max)))) + return int(randomInt.Int64()) +} + +type EncryptedString struct { + KeyID string `json:"key_id"` + Algorithm string `json:"alg"` + Data []byte `json:"data"` + Nonce []byte `json:"nonce,omitempty"` +} + +func (es *EncryptedString) IsValid() bool { + return es.KeyID != "" && len(es.Data) > 0 && len(es.Nonce) > 0 && es.Algorithm == "aes-gcm-hkdf" +} + +// ShouldReEncrypt tells you if the value encrypted needs to be encrypted again with a newer key. +func (es *EncryptedString) ShouldReEncrypt(encryptionKeyID string) bool { + return es.KeyID != encryptionKeyID +} + +func (es *EncryptedString) Decrypt(id string, decryptionKeys map[string]string) ([]byte, error) { + decryptionKey := decryptionKeys[es.KeyID] + + if decryptionKey == "" { + return nil, fmt.Errorf("crypto: decryption key with name %q does not exist", es.KeyID) + } + + key, err := deriveSymmetricKey(id, es.KeyID, decryptionKey) + if err != nil { + return nil, err + } + + block := must(aes.NewCipher(key)) + cipher := must(cipher.NewGCM(block)) + + decrypted, err := cipher.Open(nil, es.Nonce, es.Data, nil) // #nosec G407 + if err != nil { + return nil, err + } + + return decrypted, nil +} + +func ParseEncryptedString(str string) *EncryptedString { + if !strings.HasPrefix(str, "{") { + return nil + } + + var es EncryptedString + + if err := json.Unmarshal([]byte(str), &es); err != nil { + return nil + } + + if !es.IsValid() { + return nil + } + + return &es +} + +func (es *EncryptedString) String() string { + out := must(json.Marshal(es)) + + return string(out) +} + +func deriveSymmetricKey(id, keyID, keyBase64URL string) ([]byte, error) { + hkdfKey, err := base64.RawURLEncoding.DecodeString(keyBase64URL) + if err != nil { + return nil, err + } + + if len(hkdfKey) != 256/8 { + return nil, fmt.Errorf("crypto: key with ID %q is not 256 bits", keyID) + } + + // Since we use AES-GCM here, the same symmetric key *must not be used + // more than* 2^32 times. But, that's not that much. Suppose a system + // with 100 million users, then a user can only change their password + // 42 times. To prevent this, the actual symmetric key is derived by + // using HKDF using the encryption key and the "ID" of the object + // containing the encryption string. Ideally this ID is a UUID. This + // has the added benefit that the encrypted string is bound to that + // specific object, and can't accidentally be "moved" to other objects + // without changing their ID to the original one. + + keyReader := hkdf.New(sha256.New, hkdfKey, nil, []byte(id)) + key := make([]byte, 256/8) + + must(io.ReadFull(keyReader, key)) + + return key, nil +} + +func NewEncryptedString(id string, data []byte, keyID string, keyBase64URL string) (*EncryptedString, error) { + key, err := deriveSymmetricKey(id, keyID, keyBase64URL) + if err != nil { + return nil, err + } + + block := must(aes.NewCipher(key)) + cipher := must(cipher.NewGCM(block)) + + es := EncryptedString{ + KeyID: keyID, + Algorithm: "aes-gcm-hkdf", + Nonce: make([]byte, 12), + } + + must(io.ReadFull(rand.Reader, es.Nonce)) + es.Data = cipher.Seal(nil, es.Nonce, data, nil) // #nosec G407 + + return &es, nil +} + +// SecureAlphanumeric generates a secure random alphanumeric string using standard library +func SecureAlphanumeric(length int) string { + if length < 8 { + length = 8 + } + + // Calculate bytes needed for desired length + // base32 encoding: 5 bytes -> 8 chars + numBytes := (length*5 + 7) / 8 + + b := make([]byte, numBytes) + must(io.ReadFull(rand.Reader, b)) + + // Use standard library's base32 without padding + return strings.ToLower(base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b))[:length] +} diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go new file mode 100644 index 000000000..4541c91fc --- /dev/null +++ b/internal/crypto/crypto_test.go @@ -0,0 +1,108 @@ +package crypto + +import ( + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" +) + +func TestEncryptedStringPositive(t *testing.T) { + id := uuid.Must(uuid.NewV4()).String() + + es, err := NewEncryptedString(id, []byte("data"), "key-id", "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4") + assert.NoError(t, err) + + assert.Equal(t, es.KeyID, "key-id") + assert.Equal(t, es.Algorithm, "aes-gcm-hkdf") + assert.Len(t, es.Data, 20) + assert.Len(t, es.Nonce, 12) + + dec := ParseEncryptedString(es.String()) + + assert.NotNil(t, dec) + assert.Equal(t, dec.Algorithm, "aes-gcm-hkdf") + assert.Len(t, dec.Data, 20) + assert.Len(t, dec.Nonce, 12) + + decrypted, err := dec.Decrypt(id, map[string]string{ + "key-id": "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4", + }) + + assert.NoError(t, err) + assert.Equal(t, []byte("data"), decrypted) +} + +func TestParseEncryptedStringNegative(t *testing.T) { + negativeExamples := []string{ + "not-an-encrypted-string", + // not json + "{{", + // not parsable json + `{"key_id":1}`, + `{"alg":1}`, + `{"data":"!!!"}`, + `{"nonce":"!!!"}`, + // not valid + `{}`, + `{"key_id":"key_id"}`, + `{"key_id":"key_id","alg":"different","data":"AQAB=","nonce":"AQAB="}`, + } + + for _, example := range negativeExamples { + assert.Nil(t, ParseEncryptedString(example)) + } +} + +func TestEncryptedStringDecryptNegative(t *testing.T) { + id := uuid.Must(uuid.NewV4()).String() + + // short key + _, err := NewEncryptedString(id, []byte("data"), "key-id", "short_key") + assert.Error(t, err) + + // not base64 + _, err = NewEncryptedString(id, []byte("data"), "key-id", "!!!") + assert.Error(t, err) + + es, err := NewEncryptedString(id, []byte("data"), "key-id", "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4") + assert.NoError(t, err) + + dec := ParseEncryptedString(es.String()) + assert.NotNil(t, dec) + + _, err = dec.Decrypt(id, map[string]string{ + // empty map + }) + assert.Error(t, err) + + // short key + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "AQAB", + }) + assert.Error(t, err) + + // key not base64 + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "!!!", + }) + assert.Error(t, err) + + // bad key + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + }) + assert.Error(t, err) + + // bad tag for AEAD failure + dec.Data[len(dec.Data)-1] += 1 + + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4", + }) + assert.Error(t, err) +} + +func TestSecureToken(t *testing.T) { + assert.Equal(t, len(SecureAlphanumeric(22)), 22) +} diff --git a/internal/crypto/password.go b/internal/crypto/password.go new file mode 100644 index 000000000..7cf46073d --- /dev/null +++ b/internal/crypto/password.go @@ -0,0 +1,418 @@ +package crypto + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/supabase/auth/internal/observability" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "golang.org/x/crypto/argon2" + "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/scrypt" +) + +type HashCost = int + +const ( + // DefaultHashCost represents the default + // hashing cost for any hashing algorithm. + DefaultHashCost HashCost = iota + + // QuickHashCosts represents the quickest + // hashing cost for any hashing algorithm, + // useful for tests only. + QuickHashCost HashCost = iota + + Argon2Prefix = "$argon2" + FirebaseScryptPrefix = "$fbscrypt" + FirebaseScryptKeyLen = 32 // Firebase uses AES-256 which requires 32 byte keys: https://pkg.go.dev/golang.org/x/crypto/scrypt#Key +) + +// PasswordHashCost is the current pasword hashing cost +// for all new hashes generated with +// GenerateHashFromPassword. +var PasswordHashCost = DefaultHashCost + +var ( + generateFromPasswordSubmittedCounter = observability.ObtainMetricCounter("gotrue_generate_from_password_submitted", "Number of submitted GenerateFromPassword hashing attempts") + generateFromPasswordCompletedCounter = observability.ObtainMetricCounter("gotrue_generate_from_password_completed", "Number of completed GenerateFromPassword hashing attempts") +) + +var ( + compareHashAndPasswordSubmittedCounter = observability.ObtainMetricCounter("gotrue_compare_hash_and_password_submitted", "Number of submitted CompareHashAndPassword hashing attempts") + compareHashAndPasswordCompletedCounter = observability.ObtainMetricCounter("gotrue_compare_hash_and_password_completed", "Number of completed CompareHashAndPassword hashing attempts") +) + +var ErrArgon2MismatchedHashAndPassword = errors.New("crypto: argon2 hash and password mismatch") +var ErrScryptMismatchedHashAndPassword = errors.New("crypto: fbscrypt hash and password mismatch") + +// argon2HashRegexp https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md#argon2-encoding +var argon2HashRegexp = regexp.MustCompile("^[$](?Pargon2(d|i|id))[$]v=(?P(16|19))[$]m=(?P[0-9]+),t=(?P[0-9]+),p=(?P

[0-9]+)(,keyid=(?P[^,$]+))?(,data=(?P[^$]+))?[$](?P[^$]*)[$](?P.*)$") +var fbscryptHashRegexp = regexp.MustCompile(`^\$fbscrypt\$v=(?P[0-9]+),n=(?P[0-9]+),r=(?P[0-9]+),p=(?P

[0-9]+)(?:,ss=(?P[^,]+))?(?:,sk=(?P[^$]+))?\$(?P[^$]+)\$(?P.+)$`) + +type Argon2HashInput struct { + alg string + v string + memory uint64 + time uint64 + threads uint64 + keyid string + data string + salt []byte + rawHash []byte +} + +type FirebaseScryptHashInput struct { + v string + memory uint64 + rounds uint64 + threads uint64 + saltSeparator []byte + signerKey []byte + salt []byte + rawHash []byte +} + +// See: https://github.com/firebase/scrypt for implementation +func ParseFirebaseScryptHash(hash string) (*FirebaseScryptHashInput, error) { + submatch := fbscryptHashRegexp.FindStringSubmatchIndex(hash) + if submatch == nil { + return nil, errors.New("crypto: incorrect scrypt hash format") + } + + v := string(fbscryptHashRegexp.ExpandString(nil, "$v", hash, submatch)) + n := string(fbscryptHashRegexp.ExpandString(nil, "$n", hash, submatch)) + r := string(fbscryptHashRegexp.ExpandString(nil, "$r", hash, submatch)) + p := string(fbscryptHashRegexp.ExpandString(nil, "$p", hash, submatch)) + ss := string(fbscryptHashRegexp.ExpandString(nil, "$ss", hash, submatch)) + sk := string(fbscryptHashRegexp.ExpandString(nil, "$sk", hash, submatch)) + saltB64 := string(fbscryptHashRegexp.ExpandString(nil, "$salt", hash, submatch)) + hashB64 := string(fbscryptHashRegexp.ExpandString(nil, "$hash", hash, submatch)) + + if v != "1" { + return nil, fmt.Errorf("crypto: Firebase scrypt hash uses unsupported version %q only version 1 is supported", v) + } + memoryPower, err := strconv.ParseUint(n, 10, 32) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid n parameter %q %w", n, err) + } + if memoryPower == 0 { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid n=0") + } + rounds, err := strconv.ParseUint(r, 10, 64) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid r parameter %q: %w", r, err) + } + if rounds == 0 { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid r=0") + } + + threads, err := strconv.ParseUint(p, 10, 8) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid p parameter %q %w", p, err) + } + if threads == 0 { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid p=0") + } + + rawHash, err := base64.StdEncoding.DecodeString(hashB64) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid base64 in the hash section %w", err) + } + + salt, err := base64.StdEncoding.DecodeString(saltB64) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt salt has invalid base64 in the hash section %w", err) + } + + var saltSeparator, signerKey []byte + if signerKey, err = base64.StdEncoding.DecodeString(sk); err != nil { + return nil, err + } + if saltSeparator, err = base64.StdEncoding.DecodeString(ss); err != nil { + return nil, err + } + + input := &FirebaseScryptHashInput{ + v: v, + memory: uint64(1) << memoryPower, + rounds: rounds, + threads: threads, + salt: salt, + rawHash: rawHash, + saltSeparator: saltSeparator, + signerKey: signerKey, + } + + return input, nil +} + +func ParseArgon2Hash(hash string) (*Argon2HashInput, error) { + submatch := argon2HashRegexp.FindStringSubmatchIndex(hash) + if submatch == nil { + return nil, errors.New("crypto: incorrect argon2 hash format") + } + + alg := string(argon2HashRegexp.ExpandString(nil, "$alg", hash, submatch)) + v := string(argon2HashRegexp.ExpandString(nil, "$v", hash, submatch)) + m := string(argon2HashRegexp.ExpandString(nil, "$m", hash, submatch)) + t := string(argon2HashRegexp.ExpandString(nil, "$t", hash, submatch)) + p := string(argon2HashRegexp.ExpandString(nil, "$p", hash, submatch)) + keyid := string(argon2HashRegexp.ExpandString(nil, "$keyid", hash, submatch)) + data := string(argon2HashRegexp.ExpandString(nil, "$data", hash, submatch)) + saltB64 := string(argon2HashRegexp.ExpandString(nil, "$salt", hash, submatch)) + hashB64 := string(argon2HashRegexp.ExpandString(nil, "$hash", hash, submatch)) + + if alg != "argon2i" && alg != "argon2id" { + return nil, fmt.Errorf("crypto: argon2 hash uses unsupported algorithm %q only argon2i and argon2id supported", alg) + } + + if v != "19" { + return nil, fmt.Errorf("crypto: argon2 hash uses unsupported version %q only %d is supported", v, argon2.Version) + } + + if data != "" { + return nil, fmt.Errorf("crypto: argon2 hashes with the data parameter not supported") + } + + if keyid != "" { + return nil, fmt.Errorf("crypto: argon2 hashes with the keyid parameter not supported") + } + + memory, err := strconv.ParseUint(m, 10, 32) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid m parameter %q %w", m, err) + } + + time, err := strconv.ParseUint(t, 10, 32) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid t parameter %q %w", t, err) + } + + threads, err := strconv.ParseUint(p, 10, 8) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid p parameter %q %w", p, err) + } + + rawHash, err := base64.RawStdEncoding.DecodeString(hashB64) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid base64 in the hash section %w", err) + } + if len(rawHash) == 0 { + return nil, errors.New("crypto: argon2 hash is empty") + } + + salt, err := base64.RawStdEncoding.DecodeString(saltB64) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid base64 in the salt section %w", err) + } + if len(salt) == 0 { + return nil, errors.New("crypto: argon2 salt is empty") + } + + input := Argon2HashInput{ + alg: alg, + v: v, + memory: memory, + time: time, + threads: threads, + keyid: keyid, + data: data, + salt: salt, + rawHash: rawHash, + } + + return &input, nil +} + +func compareHashAndPasswordArgon2(ctx context.Context, hash, password string) error { + input, err := ParseArgon2Hash(hash) + if err != nil { + return err + } + + attributes := []attribute.KeyValue{ + attribute.String("alg", input.alg), + attribute.String("v", input.v), + attribute.Int64("m", int64(input.memory)), + attribute.Int64("t", int64(input.time)), + attribute.Int("p", int(input.threads)), + attribute.Int("len", len(input.rawHash)), + } // #nosec G115 + + var match bool + var derivedKey []byte + compareHashAndPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + defer func() { + attributes = append(attributes, attribute.Bool( + "match", + match, + )) + + compareHashAndPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + }() + + switch input.alg { + case "argon2i": + derivedKey = argon2.Key([]byte(password), input.salt, uint32(input.time), uint32(input.memory), uint8(input.threads), uint32(len(input.rawHash))) // #nosec G115 + + case "argon2id": + derivedKey = argon2.IDKey([]byte(password), input.salt, uint32(input.time), uint32(input.memory), uint8(input.threads), uint32(len(input.rawHash))) // #nosec G115 + } + + match = subtle.ConstantTimeCompare(derivedKey, input.rawHash) == 1 + + if !match { + return ErrArgon2MismatchedHashAndPassword + } + + return nil +} + +func compareHashAndPasswordFirebaseScrypt(ctx context.Context, hash, password string) error { + input, err := ParseFirebaseScryptHash(hash) + if err != nil { + return err + } + + attributes := []attribute.KeyValue{ + attribute.String("v", input.v), + attribute.Int64("n", int64(input.memory)), + attribute.Int64("r", int64(input.rounds)), + attribute.Int("p", int(input.threads)), + attribute.Int("len", len(input.rawHash)), + } // #nosec G115 + + var match bool + compareHashAndPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + defer func() { + attributes = append(attributes, attribute.Bool("match", match)) + compareHashAndPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + }() + + derivedKey := firebaseScrypt([]byte(password), input.salt, input.signerKey, input.saltSeparator, input.memory, input.rounds, input.threads) + + match = subtle.ConstantTimeCompare(derivedKey, input.rawHash) == 1 + if !match { + return ErrScryptMismatchedHashAndPassword + } + + return nil +} + +func firebaseScrypt(password, salt, signerKey, saltSeparator []byte, memCost, rounds, p uint64) []byte { + ck := must(scrypt.Key(password, append(salt, saltSeparator...), int(memCost), int(rounds), int(p), FirebaseScryptKeyLen)) // #nosec G115 + block := must(aes.NewCipher(ck)) + + cipherText := make([]byte, aes.BlockSize+len(signerKey)) + + // #nosec G407 -- Firebase scrypt requires deterministic IV for consistent results. See: JaakkoL/firebase-scrypt-python@master/firebasescrypt/firebasescrypt.py#L58 + stream := cipher.NewCTR(block, cipherText[:aes.BlockSize]) + stream.XORKeyStream(cipherText[aes.BlockSize:], signerKey) + + return cipherText[aes.BlockSize:] +} + +// CompareHashAndPassword compares the hash and +// password, returns nil if equal otherwise an error. Context can be used to +// cancel the hashing if the algorithm supports it. +func CompareHashAndPassword(ctx context.Context, hash, password string) error { + if strings.HasPrefix(hash, Argon2Prefix) { + return compareHashAndPasswordArgon2(ctx, hash, password) + } else if strings.HasPrefix(hash, FirebaseScryptPrefix) { + return compareHashAndPasswordFirebaseScrypt(ctx, hash, password) + } + + // assume bcrypt + hashCost, err := bcrypt.Cost([]byte(hash)) + if err != nil { + return err + } + + attributes := []attribute.KeyValue{ + attribute.String("alg", "bcrypt"), + attribute.Int("bcrypt_cost", hashCost), + } + + compareHashAndPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + defer func() { + attributes = append(attributes, attribute.Bool( + "match", + !errors.Is(err, bcrypt.ErrMismatchedHashAndPassword), + )) + + compareHashAndPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + }() + + err = bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err +} + +// GenerateFromPassword generates a password hash from a +// password, using PasswordHashCost. Context can be used to cancel the hashing +// if the algorithm supports it. +func GenerateFromPassword(ctx context.Context, password string) (string, error) { + hashCost := bcrypt.DefaultCost + + switch PasswordHashCost { + case QuickHashCost: + hashCost = bcrypt.MinCost + } + + attributes := []attribute.KeyValue{ + attribute.String("alg", "bcrypt"), + attribute.Int("bcrypt_cost", hashCost), + } + + generateFromPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + defer generateFromPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + + hash := must(bcrypt.GenerateFromPassword([]byte(password), hashCost)) + + return string(hash), nil +} + +func GeneratePassword(requiredChars []string, length int) string { + passwordBuilder := strings.Builder{} + passwordBuilder.Grow(length) + + // Add required characters + for _, group := range requiredChars { + if len(group) > 0 { + randomIndex := secureRandomInt(len(group)) + + passwordBuilder.WriteByte(group[randomIndex]) + } + } + + // Define a default character set for random generation (if needed) + const allChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + // Fill the rest of the password + for passwordBuilder.Len() < length { + randomIndex := secureRandomInt(len(allChars)) + passwordBuilder.WriteByte(allChars[randomIndex]) + } + + // Convert to byte slice for shuffling + passwordBytes := []byte(passwordBuilder.String()) + + // Secure shuffling + for i := len(passwordBytes) - 1; i > 0; i-- { + j := secureRandomInt(i + 1) + + passwordBytes[i], passwordBytes[j] = passwordBytes[j], passwordBytes[i] + } + + return string(passwordBytes) +} diff --git a/internal/crypto/password_test.go b/internal/crypto/password_test.go new file mode 100644 index 000000000..289c9fe5b --- /dev/null +++ b/internal/crypto/password_test.go @@ -0,0 +1,178 @@ +package crypto + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestArgon2(t *testing.T) { + // all of these hash the `test` string with various parameters + + examples := []string{ + "$argon2i$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + "$argon2id$v=19$m=32,t=3,p=2$SFVpOWJ0eXhjRzVkdGN1RQ$RXnb8rh7LaDcn07xsssqqulZYXOM/EUCEFMVcAcyYVk", + } + + for _, example := range examples { + assert.NoError(t, CompareHashAndPassword(context.Background(), example, "test")) + } + + for _, example := range examples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test1")) + } + + negativeExamples := []string{ + // 2d + "$argon2d$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // v=16 + "$argon2id$v=16$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // data + "$argon2id$v=19$m=16,t=2,p=1,data=abc$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // keyid + "$argon2id$v=19$m=16,t=2,p=1,keyid=abc$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // m larger than 32 bits + "$argon2id$v=19$m=4294967297,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // t larger than 32 bits + "$argon2id$v=19$m=16,t=4294967297,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // p larger than 8 bits + "$argon2id$v=19$m=16,t=2,p=256$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // salt not Base64 + "$argon2id$v=19$m=16,t=2,p=1$!!!$NfEnUOuUpb7F2fQkgFUG4g", + // hash not Base64 + "$argon2id$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$!!!", + // salt empty + "$argon2id$v=19$m=16,t=2,p=1$$NfEnUOuUpb7F2fQkgFUG4g", + // hash empty + "$argon2id$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$", + } + + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test")) + } +} + +func TestGeneratePassword(t *testing.T) { + tests := []struct { + name string + requiredChars []string + length int + }{ + { + name: "Valid password generation", + requiredChars: []string{"ABC", "123", "@#$"}, + length: 12, + }, + { + name: "Empty required chars", + requiredChars: []string{}, + length: 8, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GeneratePassword(tt.requiredChars, tt.length) + + if len(got) != tt.length { + t.Errorf("GeneratePassword() returned password of length %d, want %d", len(got), tt.length) + } + + // Check if all required characters are present + for _, chars := range tt.requiredChars { + found := false + for _, c := range got { + if strings.ContainsRune(chars, c) { + found = true + break + } + } + if !found && len(chars) > 0 { + t.Errorf("GeneratePassword() missing required character from set %s", chars) + } + } + }) + } + + // Check for duplicates passwords + passwords := make(map[string]bool) + for i := 0; i < 30; i++ { + p := GeneratePassword([]string{"ABC", "123", "@#$"}, 30) + + if passwords[p] { + t.Errorf("GeneratePassword() generated duplicate password: %s", p) + } + passwords[p] = true + } +} + +func TestFirebaseScrypt(t *testing.T) { + // all of these use the `mytestpassword` string as the valid one + + examples := []string{ + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + } + + for _, example := range examples { + assert.NoError(t, CompareHashAndPassword(context.Background(), example, "mytestpassword")) + } + + for _, example := range examples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "mytestpassword1")) + } + + negativeExamples := []string{ + // v not 1 + "$fbscrypt$v=2,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // n not 32 bits + "$fbscrypt$v=1,n=4294967297,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // n is 0 + "$fbscrypt$v=1,n=0,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // rounds is not 64 bits + "$fbscrypt$v=1,n=14,r=18446744073709551617,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // rounds is 0 + "$fbscrypt$v=1,n=14,r=0,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // threads is not 8 bits + "$fbscrypt$v=1,n=14,r=8,p=256,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // threads is 0 + "$fbscrypt$v=1,n=14,r=8,p=0,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // hash is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$!!!", + // salt is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$!!!$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // signer key is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=!!!$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // salt separator is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=!!!,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + } + + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "mytestpassword")) + } +} + +func TestBcrypt(t *testing.T) { + // all use the `test` password + + examples := []string{ + "$2y$04$mIJxfrCaEI3GukZe11CiXublhEFanu5.ododkll1WphfSp6pn4zIu", + "$2y$10$srNl09aPtc2qr.0Vl.NtjekJRt/NxRxYQm3qd3OvfcKsJgVnr6.Ve", + } + + for _, example := range examples { + assert.NoError(t, CompareHashAndPassword(context.Background(), example, "test")) + } + + for _, example := range examples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test1")) + } + + negativeExamples := []string{ + "not-a-hash", + } + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test")) + } +} diff --git a/internal/crypto/utils.go b/internal/crypto/utils.go new file mode 100644 index 000000000..a6b38b8e8 --- /dev/null +++ b/internal/crypto/utils.go @@ -0,0 +1,9 @@ +package crypto + +func must[T any](a T, err error) T { + if err != nil { + panic(err) + } + + return a +} diff --git a/internal/crypto/utils_test.go b/internal/crypto/utils_test.go new file mode 100644 index 000000000..1aeeab80c --- /dev/null +++ b/internal/crypto/utils_test.go @@ -0,0 +1,14 @@ +package crypto + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestMust(t *testing.T) { + require.Panics(t, func() { + must(123, errors.New("panic")) + }) +} diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go new file mode 100644 index 000000000..1b881d36f --- /dev/null +++ b/internal/hooks/auth_hooks.go @@ -0,0 +1,220 @@ +package hooks + +import ( + "github.com/gofrs/uuid" + "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" +) + +type HookType string + +const ( + PostgresHook HookType = "pg-functions" +) + +const ( + // In Miliseconds + DefaultTimeout = 2000 +) + +// Hook Names +const ( + HookRejection = "reject" +) + +type HTTPHookInput interface { + IsHTTPHook() +} + +type HookOutput interface { + IsError() bool + Error() string +} + +// TODO(joel): Move this to phone package +type SMS struct { + OTP string `json:"otp,omitempty"` + SMSType string `json:"sms_type,omitempty"` +} + +// #nosec +const MinimumViableTokenSchema = `{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "aud": { + "type": ["string", "array"] + }, + "exp": { + "type": "integer" + }, + "jti": { + "type": "string" + }, + "iat": { + "type": "integer" + }, + "iss": { + "type": "string" + }, + "nbf": { + "type": "integer" + }, + "sub": { + "type": "string" + }, + "email": { + "type": "string" + }, + "phone": { + "type": "string" + }, + "app_metadata": { + "type": "object", + "additionalProperties": true + }, + "user_metadata": { + "type": "object", + "additionalProperties": true + }, + "role": { + "type": "string" + }, + "aal": { + "type": "string" + }, + "amr": { + "type": "array", + "items": { + "type": "object" + } + }, + "session_id": { + "type": "string" + } + }, + "required": ["aud", "exp", "iat", "sub", "email", "phone", "role", "aal", "session_id", "is_anonymous"] +}` + +// AccessTokenClaims is a struct thats used for JWT claims +type AccessTokenClaims struct { + jwt.RegisteredClaims + Email string `json:"email"` + Phone string `json:"phone"` + AppMetaData map[string]interface{} `json:"app_metadata"` + UserMetaData map[string]interface{} `json:"user_metadata"` + Role string `json:"role"` + AuthenticatorAssuranceLevel string `json:"aal,omitempty"` + AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"` + SessionId string `json:"session_id,omitempty"` + IsAnonymous bool `json:"is_anonymous"` +} + +type MFAVerificationAttemptInput struct { + UserID uuid.UUID `json:"user_id"` + FactorID uuid.UUID `json:"factor_id"` + FactorType string `json:"factor_type"` + Valid bool `json:"valid"` +} + +type MFAVerificationAttemptOutput struct { + Decision string `json:"decision"` + Message string `json:"message"` + HookError AuthHookError `json:"error"` +} + +type PasswordVerificationAttemptInput struct { + UserID uuid.UUID `json:"user_id"` + Valid bool `json:"valid"` +} + +type PasswordVerificationAttemptOutput struct { + Decision string `json:"decision"` + Message string `json:"message"` + ShouldLogoutUser bool `json:"should_logout_user"` + HookError AuthHookError `json:"error"` +} + +type CustomAccessTokenInput struct { + UserID uuid.UUID `json:"user_id"` + Claims *AccessTokenClaims `json:"claims"` + AuthenticationMethod string `json:"authentication_method"` +} + +type CustomAccessTokenOutput struct { + Claims map[string]interface{} `json:"claims"` + HookError AuthHookError `json:"error,omitempty"` +} + +type SendSMSInput struct { + User *models.User `json:"user,omitempty"` + SMS SMS `json:"sms,omitempty"` +} + +type SendSMSOutput struct { + HookError AuthHookError `json:"error,omitempty"` +} + +type SendEmailInput struct { + User *models.User `json:"user"` + EmailData mailer.EmailData `json:"email_data"` +} + +type SendEmailOutput struct { + HookError AuthHookError `json:"error,omitempty"` +} + +func (mf *MFAVerificationAttemptOutput) IsError() bool { + return mf.HookError.Message != "" +} + +func (mf *MFAVerificationAttemptOutput) Error() string { + return mf.HookError.Message +} + +func (p *PasswordVerificationAttemptOutput) IsError() bool { + return p.HookError.Message != "" +} + +func (p *PasswordVerificationAttemptOutput) Error() string { + return p.HookError.Message +} + +func (ca *CustomAccessTokenOutput) IsError() bool { + return ca.HookError.Message != "" +} + +func (ca *CustomAccessTokenOutput) Error() string { + return ca.HookError.Message +} + +func (cs *SendSMSOutput) IsError() bool { + return cs.HookError.Message != "" +} + +func (cs *SendSMSOutput) Error() string { + return cs.HookError.Message +} + +func (cs *SendEmailOutput) IsError() bool { + return cs.HookError.Message != "" +} + +func (cs *SendEmailOutput) Error() string { + return cs.HookError.Message +} + +type AuthHookError struct { + HTTPCode int `json:"http_code,omitempty"` + Message string `json:"message,omitempty"` +} + +func (a *AuthHookError) Error() string { + return a.Message +} + +const ( + DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected." + DefaultPasswordHookRejectionMessage = "Further password verification attempts will be rejected." +) diff --git a/internal/mailer/mailer.go b/internal/mailer/mailer.go new file mode 100644 index 000000000..1499960f5 --- /dev/null +++ b/internal/mailer/mailer.go @@ -0,0 +1,93 @@ +package mailer + +import ( + "fmt" + "net/http" + "net/url" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +// Mailer defines the interface a mailer must implement. +type Mailer interface { + InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error + ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error + RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error + MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error + EmailChangeMail(r *http.Request, user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error + ReauthenticateMail(r *http.Request, user *models.User, otp string) error + GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) +} + +type EmailParams struct { + Token string + Type string + RedirectTo string +} + +type EmailData struct { + Token string `json:"token"` + TokenHash string `json:"token_hash"` + RedirectTo string `json:"redirect_to"` + EmailActionType string `json:"email_action_type"` + SiteURL string `json:"site_url"` + TokenNew string `json:"token_new"` + TokenHashNew string `json:"token_hash_new"` +} + +// NewMailer returns a new gotrue mailer +func NewMailer(globalConfig *conf.GlobalConfiguration) Mailer { + from := globalConfig.SMTP.FromAddress() + u, _ := url.ParseRequestURI(globalConfig.API.ExternalURL) + + var mailClient MailClient + if globalConfig.SMTP.Host == "" { + logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) + mailClient = &noopMailClient{ + EmailValidator: newEmailValidator(globalConfig.Mailer), + } + } else { + mailClient = &MailmeMailer{ + Host: globalConfig.SMTP.Host, + Port: globalConfig.SMTP.Port, + User: globalConfig.SMTP.User, + Pass: globalConfig.SMTP.Pass, + LocalName: u.Hostname(), + From: from, + BaseURL: globalConfig.SiteURL, + Logger: logrus.StandardLogger(), + MailLogging: globalConfig.SMTP.LoggingEnabled, + EmailValidator: newEmailValidator(globalConfig.Mailer), + } + } + + return &TemplateMailer{ + SiteURL: globalConfig.SiteURL, + Config: globalConfig, + Mailer: mailClient, + } +} + +func withDefault(value, defaultValue string) string { + if value == "" { + return defaultValue + } + return value +} + +func getPath(filepath string, params *EmailParams) (*url.URL, error) { + path := &url.URL{} + if filepath != "" { + if p, err := url.Parse(filepath); err != nil { + return nil, err + } else { + path = p + } + } + if params != nil { + path.RawQuery = fmt.Sprintf("token=%s&type=%s&redirect_to=%s", url.QueryEscape(params.Token), url.QueryEscape(params.Type), encodeRedirectURL(params.RedirectTo)) + } + return path, nil +} diff --git a/internal/mailer/mailer_test.go b/internal/mailer/mailer_test.go new file mode 100644 index 000000000..290d65dd0 --- /dev/null +++ b/internal/mailer/mailer_test.go @@ -0,0 +1,87 @@ +package mailer + +import ( + "net/url" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +var urlRegexp = regexp.MustCompile(`^https?://[^/]+`) + +func enforceRelativeURL(url string) string { + return urlRegexp.ReplaceAllString(url, "") +} + +func TestGetPath(t *testing.T) { + params := EmailParams{ + Token: "token", + Type: "signup", + RedirectTo: "https://example.com", + } + cases := []struct { + SiteURL string + Path string + Params *EmailParams + Expected string + }{ + { + SiteURL: "https://test.example.com", + Path: "/templates/confirm.html", + Params: nil, + Expected: "https://test.example.com/templates/confirm.html", + }, + { + SiteURL: "https://test.example.com/removedpath", + Path: "/templates/confirm.html", + Params: nil, + Expected: "https://test.example.com/templates/confirm.html", + }, + { + SiteURL: "https://test.example.com/", + Path: "/trailingslash/", + Params: nil, + Expected: "https://test.example.com/trailingslash/", + }, + { + SiteURL: "https://test.example.com", + Path: "f", + Params: ¶ms, + Expected: "https://test.example.com/f?token=token&type=signup&redirect_to=https://example.com", + }, + { + SiteURL: "https://test.example.com", + Path: "", + Params: ¶ms, + Expected: "https://test.example.com?token=token&type=signup&redirect_to=https://example.com", + }, + } + + for _, c := range cases { + u, err := url.ParseRequestURI(c.SiteURL) + assert.NoError(t, err, "error parsing URI request") + + path, err := getPath(c.Path, c.Params) + + assert.NoError(t, err) + assert.Equal(t, c.Expected, u.ResolveReference(path).String()) + } +} + +func TestRelativeURL(t *testing.T) { + cases := []struct { + URL string + Expected string + }{ + {"https://test.example.com", ""}, + {"http://test.example.com", ""}, + {"test.example.com", "test.example.com"}, + {"/some/path#fragment", "/some/path#fragment"}, + } + + for _, c := range cases { + res := enforceRelativeURL(c.URL) + assert.Equal(t, c.Expected, res, c.URL) + } +} diff --git a/internal/mailer/mailme.go b/internal/mailer/mailme.go new file mode 100644 index 000000000..20ff17740 --- /dev/null +++ b/internal/mailer/mailme.go @@ -0,0 +1,230 @@ +package mailer + +import ( + "bytes" + "context" + "errors" + "html/template" + "io" + "log" + "net/http" + "strings" + "sync" + "time" + + "gopkg.in/gomail.v2" + + "github.com/sirupsen/logrus" +) + +// TemplateRetries is the amount of time MailMe will try to fetch a URL before giving up +const TemplateRetries = 3 + +// TemplateExpiration is the time period that the template will be cached for +const TemplateExpiration = 10 * time.Second + +// MailmeMailer lets MailMe send templated mails +type MailmeMailer struct { + From string + Host string + Port int + User string + Pass string + BaseURL string + LocalName string + FuncMap template.FuncMap + cache *TemplateCache + Logger logrus.FieldLogger + MailLogging bool + EmailValidator *EmailValidator +} + +// Mail sends a templated mail. It will try to load the template from a URL, and +// otherwise fall back to the default +func (m *MailmeMailer) Mail( + ctx context.Context, + to, subjectTemplate, templateURL, defaultTemplate string, + templateData map[string]interface{}, + headers map[string][]string, + typ string, +) error { + if m.FuncMap == nil { + m.FuncMap = map[string]interface{}{} + } + if m.cache == nil { + m.cache = &TemplateCache{ + templates: map[string]*MailTemplate{}, + funcMap: m.FuncMap, + logger: m.Logger, + } + } + + if m.EmailValidator != nil { + if err := m.EmailValidator.Validate(ctx, to); err != nil { + return err + } + } + + tmp, err := template.New("Subject").Funcs(template.FuncMap(m.FuncMap)).Parse(subjectTemplate) + if err != nil { + return err + } + + subject := &bytes.Buffer{} + err = tmp.Execute(subject, templateData) + if err != nil { + return err + } + + body, err := m.MailBody(templateURL, defaultTemplate, templateData) + if err != nil { + return err + } + + mail := gomail.NewMessage() + mail.SetHeader("From", m.From) + mail.SetHeader("To", to) + mail.SetHeader("Subject", subject.String()) + + for k, v := range headers { + if v != nil { + mail.SetHeader(k, v...) + } + } + + mail.SetBody("text/html", body) + + dial := gomail.NewDialer(m.Host, m.Port, m.User, m.Pass) + if m.LocalName != "" { + dial.LocalName = m.LocalName + } + + if m.MailLogging { + defer func() { + fields := logrus.Fields{ + "event": "mail.send", + "mail_type": typ, + "mail_from": m.From, + "mail_to": to, + } + m.Logger.WithFields(fields).Info("mail.send") + }() + } + if err := dial.DialAndSend(mail); err != nil { + return err + } + return nil +} + +type MailTemplate struct { + tmp *template.Template + expiresAt time.Time +} + +type TemplateCache struct { + templates map[string]*MailTemplate + mutex sync.Mutex + funcMap template.FuncMap + logger logrus.FieldLogger +} + +func (t *TemplateCache) Get(url string) (*template.Template, error) { + cached, ok := t.templates[url] + if ok && (cached.expiresAt.Before(time.Now())) { + return cached.tmp, nil + } + data, err := t.fetchTemplate(url, TemplateRetries) + if err != nil { + return nil, err + } + return t.Set(url, data, TemplateExpiration) +} + +func (t *TemplateCache) Set(key, value string, expirationTime time.Duration) (*template.Template, error) { + parsed, err := template.New(key).Funcs(t.funcMap).Parse(value) + if err != nil { + return nil, err + } + + cached := &MailTemplate{ + tmp: parsed, + expiresAt: time.Now().Add(expirationTime), + } + t.mutex.Lock() + t.templates[key] = cached + t.mutex.Unlock() + return parsed, nil +} + +func (t *TemplateCache) fetchTemplate(url string, triesLeft int) (string, error) { + client := &http.Client{ + Timeout: 10 * time.Second, + } + + resp, err := client.Get(url) + if err != nil && triesLeft > 0 { + return t.fetchTemplate(url, triesLeft-1) + } + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode == 200 { // OK + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil && triesLeft > 0 { + return t.fetchTemplate(url, triesLeft-1) + } + if err != nil { + return "", err + } + return string(bodyBytes), err + } + if triesLeft > 0 { + return t.fetchTemplate(url, triesLeft-1) + } + return "", errors.New("mailer: unable to fetch mail template") +} + +func (m *MailmeMailer) MailBody(url string, defaultTemplate string, data map[string]interface{}) (string, error) { + if m.FuncMap == nil { + m.FuncMap = map[string]interface{}{} + } + if m.cache == nil { + m.cache = &TemplateCache{templates: map[string]*MailTemplate{}, funcMap: m.FuncMap} + } + + var temp *template.Template + var err error + + if url != "" { + var absoluteURL string + if strings.HasPrefix(url, "http") { + absoluteURL = url + } else { + absoluteURL = m.BaseURL + url + } + temp, err = m.cache.Get(absoluteURL) + if err != nil { + log.Printf("Error loading template from %v: %v\n", url, err) + } + } + + if temp == nil { + cached, ok := m.cache.templates[url] + if ok { + temp = cached.tmp + } else { + temp, err = m.cache.Set(url, defaultTemplate, 0) + if err != nil { + return "", err + } + } + } + + buf := &bytes.Buffer{} + err = temp.Execute(buf, data) + if err != nil { + return "", err + } + return buf.String(), nil +} diff --git a/internal/mailer/noop.go b/internal/mailer/noop.go new file mode 100644 index 000000000..0e0e3bfcb --- /dev/null +++ b/internal/mailer/noop.go @@ -0,0 +1,28 @@ +package mailer + +import ( + "context" + "errors" +) + +type noopMailClient struct { + EmailValidator *EmailValidator +} + +func (m *noopMailClient) Mail( + ctx context.Context, + to, subjectTemplate, templateURL, defaultTemplate string, + templateData map[string]interface{}, + headers map[string][]string, + typ string, +) error { + if to == "" { + return errors.New("to field cannot be empty") + } + if m.EmailValidator != nil { + if err := m.EmailValidator.Validate(ctx, to); err != nil { + return err + } + } + return nil +} diff --git a/internal/mailer/template.go b/internal/mailer/template.go new file mode 100644 index 000000000..59a485450 --- /dev/null +++ b/internal/mailer/template.go @@ -0,0 +1,420 @@ +package mailer + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type MailRequest struct { + To string + SubjectTemplate string + TemplateURL string + DefaultTemplate string + TemplateData map[string]interface{} + Headers map[string][]string + Type string +} + +type MailClient interface { + Mail( + ctx context.Context, + to string, + subjectTemplate string, + templateURL string, + defaultTemplate string, + templateData map[string]interface{}, + headers map[string][]string, + typ string, + ) error +} + +// TemplateMailer will send mail and use templates from the site for easy mail styling +type TemplateMailer struct { + SiteURL string + Config *conf.GlobalConfiguration + Mailer MailClient +} + +func encodeRedirectURL(referrerURL string) string { + if len(referrerURL) > 0 { + if strings.ContainsAny(referrerURL, "&=#") { + // if the string contains &, = or # it has not been URL + // encoded by the caller, which means it should be URL + // encoded by us otherwise, it should be taken as-is + referrerURL = url.QueryEscape(referrerURL) + } + } + return referrerURL +} + +const ( + SignupVerification = "signup" + RecoveryVerification = "recovery" + InviteVerification = "invite" + MagicLinkVerification = "magiclink" + EmailChangeVerification = "email_change" + EmailOTPVerification = "email" + EmailChangeCurrentVerification = "email_change_current" + EmailChangeNewVerification = "email_change_new" + ReauthenticationVerification = "reauthentication" +) + +const defaultInviteMail = `

You have been invited

+ +

You have been invited to create a user on {{ .SiteURL }}. Follow this link to accept the invite:

+

Accept the invite

+

Alternatively, enter the code: {{ .Token }}

` + +const defaultConfirmationMail = `

Confirm your email

+ +

Follow this link to confirm your email:

+

Confirm your email address

+

Alternatively, enter the code: {{ .Token }}

+` + +const defaultRecoveryMail = `

Reset password

+ +

Follow this link to reset the password for your user:

+

Reset password

+

Alternatively, enter the code: {{ .Token }}

` + +const defaultMagicLinkMail = `

Magic Link

+ +

Follow this link to login:

+

Log In

+

Alternatively, enter the code: {{ .Token }}

` + +const defaultEmailChangeMail = `

Confirm email address change

+ +

Follow this link to confirm the update of your email address from {{ .Email }} to {{ .NewEmail }}:

+

Change email address

+

Alternatively, enter the code: {{ .Token }}

` + +const defaultReauthenticateMail = `

Confirm reauthentication

+ +

Enter the code: {{ .Token }}

` + +func (m *TemplateMailer) Headers(messageType string) map[string][]string { + originalHeaders := m.Config.SMTP.NormalizedHeaders() + + if originalHeaders == nil { + return nil + } + + headers := make(map[string][]string, len(originalHeaders)) + + for header, values := range originalHeaders { + replacedValues := make([]string, 0, len(values)) + + if header == "" { + continue + } + + for _, value := range values { + if value == "" { + continue + } + + // TODO: in the future, use a templating engine to add more contextual data available to headers + if strings.Contains(value, "$messageType") { + replacedValues = append(replacedValues, strings.ReplaceAll(value, "$messageType", messageType)) + } else { + replacedValues = append(replacedValues, value) + } + } + + headers[header] = replacedValues + } + + return headers +} + +// InviteMail sends a invite mail to a new user +func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Invite, &EmailParams{ + Token: user.ConfirmationToken, + Type: "invite", + RedirectTo: referrerURL, + }) + + if err != nil { + return err + } + + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.Email, + "Token": otp, + "TokenHash": user.ConfirmationToken, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.Invite, "You have been invited"), + m.Config.Mailer.Templates.Invite, + defaultInviteMail, + data, + m.Headers("invite"), + "invite", + ) +} + +// ConfirmationMail sends a signup confirmation mail to a new user +func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Confirmation, &EmailParams{ + Token: user.ConfirmationToken, + Type: "signup", + RedirectTo: referrerURL, + }) + if err != nil { + return err + } + + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.Email, + "Token": otp, + "TokenHash": user.ConfirmationToken, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Your Email"), + m.Config.Mailer.Templates.Confirmation, + defaultConfirmationMail, + data, + m.Headers("confirm"), + "confirm", + ) +} + +// ReauthenticateMail sends a reauthentication mail to an authenticated user +func (m *TemplateMailer) ReauthenticateMail(r *http.Request, user *models.User, otp string) error { + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "Email": user.Email, + "Token": otp, + "Data": user.UserMetaData, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.Reauthentication, "Confirm reauthentication"), + m.Config.Mailer.Templates.Reauthentication, + defaultReauthenticateMail, + data, + m.Headers("reauthenticate"), + "reauthenticate", + ) +} + +// EmailChangeMail sends an email change confirmation mail to a user +func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error { + type Email struct { + Address string + Otp string + TokenHash string + Subject string + Template string + } + emails := []Email{ + { + Address: user.EmailChange, + Otp: otpNew, + TokenHash: user.EmailChangeTokenNew, + Subject: withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), + Template: m.Config.Mailer.Templates.EmailChange, + }, + } + + currentEmail := user.GetEmail() + if m.Config.Mailer.SecureEmailChangeEnabled && currentEmail != "" { + emails = append(emails, Email{ + Address: currentEmail, + Otp: otpCurrent, + TokenHash: user.EmailChangeTokenCurrent, + Subject: withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Email Address"), + Template: m.Config.Mailer.Templates.EmailChange, + }) + } + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + errors := make(chan error, len(emails)) + for _, email := range emails { + path, err := getPath( + m.Config.Mailer.URLPaths.EmailChange, + &EmailParams{ + Token: email.TokenHash, + Type: "email_change", + RedirectTo: referrerURL, + }, + ) + if err != nil { + return err + } + go func(address, token, tokenHash, template string) { + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.GetEmail(), + "NewEmail": user.EmailChange, + "Token": token, + "TokenHash": tokenHash, + "SendingTo": address, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + errors <- m.Mailer.Mail( + ctx, + address, + withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), + template, + defaultEmailChangeMail, + data, + m.Headers("email_change"), + "email_change", + ) + }(email.Address, email.Otp, email.TokenHash, email.Template) + } + + for i := 0; i < len(emails); i++ { + e := <-errors + if e != nil { + return e + } + } + return nil +} + +// RecoveryMail sends a password recovery mail +func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + Token: user.RecoveryToken, + Type: "recovery", + RedirectTo: referrerURL, + }) + if err != nil { + return err + } + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.Email, + "Token": otp, + "TokenHash": user.RecoveryToken, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.Recovery, "Reset Your Password"), + m.Config.Mailer.Templates.Recovery, + defaultRecoveryMail, + data, + m.Headers("recovery"), + "recovery", + ) +} + +// MagicLinkMail sends a login link mail +func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + Token: user.RecoveryToken, + Type: "magiclink", + RedirectTo: referrerURL, + }) + if err != nil { + return err + } + + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.Email, + "Token": otp, + "TokenHash": user.RecoveryToken, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.MagicLink, "Your Magic Link"), + m.Config.Mailer.Templates.MagicLink, + defaultMagicLinkMail, + data, + m.Headers("magiclink"), + "magiclink", + ) +} + +// GetEmailActionLink returns a magiclink, recovery or invite link based on the actionType passed. +func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) { + var err error + var path *url.URL + + switch actionType { + case "magiclink": + path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + Token: user.RecoveryToken, + Type: "magiclink", + RedirectTo: referrerURL, + }) + case "recovery": + path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + Token: user.RecoveryToken, + Type: "recovery", + RedirectTo: referrerURL, + }) + case "invite": + path, err = getPath(m.Config.Mailer.URLPaths.Invite, &EmailParams{ + Token: user.ConfirmationToken, + Type: "invite", + RedirectTo: referrerURL, + }) + case "signup": + path, err = getPath(m.Config.Mailer.URLPaths.Confirmation, &EmailParams{ + Token: user.ConfirmationToken, + Type: "signup", + RedirectTo: referrerURL, + }) + case "email_change_current": + path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &EmailParams{ + Token: user.EmailChangeTokenCurrent, + Type: "email_change", + RedirectTo: referrerURL, + }) + case "email_change_new": + path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &EmailParams{ + Token: user.EmailChangeTokenNew, + Type: "email_change", + RedirectTo: referrerURL, + }) + default: + return "", fmt.Errorf("invalid email action link type: %s", actionType) + } + if err != nil { + return "", err + } + return externalURL.ResolveReference(path).String(), nil +} diff --git a/internal/mailer/template_test.go b/internal/mailer/template_test.go new file mode 100644 index 000000000..f8fcd7417 --- /dev/null +++ b/internal/mailer/template_test.go @@ -0,0 +1,65 @@ +package mailer + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestTemplateHeaders(t *testing.T) { + cases := []struct { + from string + typ string + exp map[string][]string + }{ + { + from: `{"x-supabase-project-ref": ["abcjrhohrqmvcpjpsyzc"]}`, + typ: "OTHER-TYPE", + exp: map[string][]string{ + "x-supabase-project-ref": {"abcjrhohrqmvcpjpsyzc"}, + }, + }, + + { + from: `{"X-Test-A": ["test-a", "test-b"], "X-Test-B": ["test-c", "abc $messageType"]}`, + typ: "TEST-MESSAGE-TYPE", + exp: map[string][]string{ + "X-Test-A": {"test-a", "test-b"}, + "X-Test-B": {"test-c", "abc TEST-MESSAGE-TYPE"}, + }, + }, + + { + from: `{"X-Test-A": ["test-a", "test-b"], "X-Test-B": ["test-c", "abc $messageType"]}`, + typ: "OTHER-TYPE", + exp: map[string][]string{ + "X-Test-A": {"test-a", "test-b"}, + "X-Test-B": {"test-c", "abc OTHER-TYPE"}, + }, + }, + + { + from: `{"X-Test-A": ["test-a", "test-b"], "X-Test-B": ["test-c", "abc $messageType"], "x-supabase-project-ref": ["abcjrhohrqmvcpjpsyzc"]}`, + typ: "OTHER-TYPE", + exp: map[string][]string{ + "X-Test-A": {"test-a", "test-b"}, + "X-Test-B": {"test-c", "abc OTHER-TYPE"}, + "x-supabase-project-ref": {"abcjrhohrqmvcpjpsyzc"}, + }, + }, + } + for _, tc := range cases { + mailer := TemplateMailer{ + Config: &conf.GlobalConfiguration{ + SMTP: conf.SMTPConfiguration{ + Headers: tc.from, + }, + }, + } + require.NoError(t, mailer.Config.SMTP.Validate()) + + hdrs := mailer.Headers(tc.typ) + require.Equal(t, hdrs, tc.exp) + } +} diff --git a/internal/mailer/validate.go b/internal/mailer/validate.go new file mode 100644 index 000000000..817d16894 --- /dev/null +++ b/internal/mailer/validate.go @@ -0,0 +1,315 @@ +package mailer + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/mail" + "strings" + "time" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/sync/errgroup" +) + +var invalidEmailMap = map[string]bool{ + + // People type these often enough to be special cased. + "test@gmail.com": true, + "example@gmail.com": true, + "someone@gmail.com": true, + "test@email.com": true, +} + +var invalidHostSuffixes = []string{ + + // These are a directly from Section 2 of RFC2606[1]. + // + // [1] https://www.rfc-editor.org/rfc/rfc2606.html#section-2 + ".test", + ".example", + ".invalid", + ".local", + ".localhost", +} + +var invalidHostMap = map[string]bool{ + + // These exist here too for when they are typed as "test@test" + "test": true, + "example": true, + "invalid": true, + "local": true, + "localhost": true, + + // These are commonly typed and have DNS records which cause a + // large enough volume of bounce backs to special case. + "test.com": true, + "example.com": true, + "example.net": true, + "example.org": true, + + // Hundreds of typos per day for this. + "gamil.com": true, + + // These are not email providers, but people often use them. + "anonymous.com": true, + "email.com": true, +} + +const ( + validateEmailTimeout = 3 * time.Second +) + +var ( + // We use the default resolver for this. + validateEmailResolver net.Resolver +) + +var ( + ErrInvalidEmailAddress = errors.New("invalid_email_address") + ErrInvalidEmailFormat = errors.New("invalid_email_format") + ErrInvalidEmailDNS = errors.New("invalid_email_dns") + ErrInvalidEmailMX = errors.New("invalid_email_mx") +) + +type EmailValidator struct { + extended bool + serviceURL string + serviceHeaders map[string][]string + blockedMXRecords map[string]bool +} + +func newEmailValidator(mc conf.MailerConfiguration) *EmailValidator { + return &EmailValidator{ + extended: mc.EmailValidationExtended, + serviceURL: mc.EmailValidationServiceURL, + serviceHeaders: mc.GetEmailValidationServiceHeaders(), + blockedMXRecords: mc.GetEmailValidationBlockedMXRecords(), + } +} + +func (ev *EmailValidator) isExtendedEnabled() bool { return ev.extended } +func (ev *EmailValidator) isServiceEnabled() bool { return ev.serviceURL != "" } + +// Validate performs validation on the given email. +// +// When extended is true, returns a nil error in all cases but the following: +// - `email` cannot be parsed by mail.ParseAddress +// - `email` has a domain with no DNS configured +// +// When serviceURL AND serviceKey are non-empty strings it uses the remote +// service to determine if the email is valid. +func (ev *EmailValidator) Validate(ctx context.Context, email string) error { + if !ev.isExtendedEnabled() && !ev.isServiceEnabled() { + return nil + } + + // One of the two validation methods are enabled, set a timeout. + ctx, cancel := context.WithTimeout(ctx, validateEmailTimeout) + defer cancel() + + // Easier control flow here to always use errgroup, it has very little + // overhad in comparison to the network calls it makes. The reason + // we run both checks concurrently is to tighten the timeout without + // potentially missing a call to the validation service due to a + // dns timeout or something more nefarious like a honeypot dns entry. + g := new(errgroup.Group) + + // Validate the static rules first to prevent round trips on bad emails + // and to parse the host ahead of time. + if ev.isExtendedEnabled() { + + // First validate static checks such as format, known invalid hosts + // and any other network free checks. Running this check before we + // call the service will help reduce the number of calls with known + // invalid emails. + host, err := ev.validateStatic(email) + if err != nil { + return err + } + + // Start the goroutine to validate the host. + g.Go(func() error { return ev.validateHost(ctx, host) }) + } + + // If the service check is enabled we start a goroutine to run + // that check as well. + if ev.isServiceEnabled() { + g.Go(func() error { return ev.validateService(ctx, email) }) + } + return g.Wait() +} + +// validateStatic will validate the format and do the static checks before +// returning the host portion of the email. +func (ev *EmailValidator) validateStatic(email string) (string, error) { + if !ev.isExtendedEnabled() { + return "", nil + } + + ea, err := mail.ParseAddress(email) + if err != nil { + return "", ErrInvalidEmailFormat + } + + i := strings.LastIndex(ea.Address, "@") + if i == -1 { + return "", ErrInvalidEmailFormat + } + + // few static lookups that are typed constantly and known to be invalid. + if invalidEmailMap[email] { + return "", ErrInvalidEmailAddress + } + + host := email[i+1:] + if invalidHostMap[host] { + return "", ErrInvalidEmailDNS + } + + for i := range invalidHostSuffixes { + if strings.HasSuffix(host, invalidHostSuffixes[i]) { + return "", ErrInvalidEmailDNS + } + } + + name := email[:i] + if err := ev.validateProviders(name, host); err != nil { + return "", err + } + return host, nil +} + +func (ev *EmailValidator) validateService(ctx context.Context, email string) error { + if !ev.isServiceEnabled() { + return nil + } + + reqObject := struct { + EmailAddress string `json:"email"` + }{email} + + reqData, err := json.Marshal(&reqObject) + if err != nil { + return nil + } + + rdr := bytes.NewReader(reqData) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ev.serviceURL, rdr) + if err != nil { + return nil + } + req.Header.Set("Content-Type", "application/json") + for name, vals := range ev.serviceHeaders { + for _, val := range vals { + req.Header.Set(name, val) + } + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil + } + defer res.Body.Close() + + resObject := struct { + Valid *bool `json:"valid"` + }{} + + if res.StatusCode/100 != 2 { + // we ignore the error here just in case the service is down + return nil + } + + dec := json.NewDecoder(io.LimitReader(res.Body, 1<<5)) + if err := dec.Decode(&resObject); err != nil { + return nil + } + + // If the object did not contain a valid key we consider the check as + // failed. We _must_ get a valid JSON response with a "valid" field. + if resObject.Valid == nil || *resObject.Valid { + return nil + } + + return ErrInvalidEmailAddress +} + +func (ev *EmailValidator) validateProviders(name, host string) error { + switch host { + case "gmail.com": + // Based on a sample of internal data, this reduces the number of + // bounced emails by 23%. Gmail documentation specifies that the + // min user name length is 6 characters. There may be some accounts + // from early gmail beta with shorter email addresses, but I think + // this reduces bounce rates enough to be worth adding for now. + if len(name) < 6 { + return ErrInvalidEmailAddress + } + } + return nil +} + +func (ev *EmailValidator) validateHost(ctx context.Context, host string) error { + mxs, err := validateEmailResolver.LookupMX(ctx, host) + if !isHostNotFound(err) { + return ev.validateMXRecords(mxs, nil) + } + + hosts, err := validateEmailResolver.LookupHost(ctx, host) + if !isHostNotFound(err) { + return ev.validateMXRecords(nil, hosts) + } + + // No addrs or mx records were found + return ErrInvalidEmailDNS +} + +func (ev *EmailValidator) validateMXRecords(mxs []*net.MX, hosts []string) error { + for _, mx := range mxs { + if ev.blockedMXRecords[mx.Host] { + return ErrInvalidEmailMX + } + } + for _, host := range hosts { + if ev.blockedMXRecords[host] { + return ErrInvalidEmailMX + } + } + return nil +} + +func isHostNotFound(err error) bool { + if err == nil { + // We had no err, so we treat it as valid. We don't check the mx records + // because RFC 5321 specifies that if an empty list of MX's are returned + // the host should be treated as the MX[1]. + // + // See section 2 and 3 of: https://www.rfc-editor.org/rfc/rfc2606 + // [1] https://www.rfc-editor.org/rfc/rfc5321.html#section-5.1 + return false + } + + // No names present, we will try to get a positive assertion that the + // domain is not configured to receive email. + var dnsError *net.DNSError + if !errors.As(err, &dnsError) { + // We will be unable to determine with absolute certainy the email was + // invalid so we will err on the side of caution and return nil. + return false + } + + // The type of err is dnsError, inspect it to see if we can be certain + // the domain has no mx records currently. For this we require that + // the error was not temporary or a timeout. If those are both false + // we trust the value in IsNotFound. + if !dnsError.IsTemporary && !dnsError.IsTimeout && dnsError.IsNotFound { + return true + } + return false +} diff --git a/internal/mailer/validate_test.go b/internal/mailer/validate_test.go new file mode 100644 index 000000000..e9bae8fef --- /dev/null +++ b/internal/mailer/validate_test.go @@ -0,0 +1,297 @@ +package mailer + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestEmalValidatorService(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Second*60) + defer cancel() + + testResVal := new(atomic.Value) + testResVal.Store(`{"valid": true}`) + + testHdrsVal := new(atomic.Value) + testHdrsVal.Store(map[string]string{"apikey": "test"}) + + // testHeaders := map[string][]string{"apikey": []string{"test"}} + testHeaders := `{"apikey": ["test"]}` + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := r.Header.Get("apikey") + if key == "" { + fmt.Fprintln(w, `{"error": true}`) + return + } + + fmt.Fprintln(w, testResVal.Load().(string)) + })) + defer ts.Close() + + // Return nil err from service + // when svc and extended checks both report email as valid + { + testResVal.Store(`{"valid": true}`) + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "chris.stockton@supabase.io") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + } + + // Return nil err from service when + // extended is disabled for a known invalid address + // service reports valid + { + testResVal.Store(`{"valid": true}`) + + cfg := conf.MailerConfiguration{ + EmailValidationExtended: false, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + } + + // Return nil err from service when + // extended is disabled for a known invalid address + // service is disabled for a known invalid address + { + testResVal.Store(`{"valid": false}`) + + cfg := conf.MailerConfiguration{ + EmailValidationExtended: false, + EmailValidationServiceURL: "", + EmailValidationServiceHeaders: "", + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + } + + // Return err from service when + // extended reports invalid + // service is disabled for a known invalid address + { + testResVal.Store(`{"valid": true}`) + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: "", + EmailValidationServiceHeaders: "", + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err == nil { + t.Fatal("exp non-nil err") + } + } + + // Return err from service when + // extended reports invalid + // service reports valid + { + testResVal.Store(`{"valid": true}`) + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err == nil { + t.Fatal("exp non-nil err") + } + } + + // Return err from service when + // extended reports valid + // service reports invalid + { + testResVal.Store(`{"valid": false}`) + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "chris.stockton@supabase.io") + if err == nil { + t.Fatal("exp non-nil err") + } + } + + // Return err from service when + // extended reports invalid + // service reports invalid + { + testResVal.Store(`{"valid": false}`) + + cfg := conf.MailerConfiguration{ + EmailValidationExtended: false, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err == nil { + t.Fatal("exp non-nil err") + } + } +} + +func TestValidateEmailExtended(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Second*60) + defer cancel() + + cases := []struct { + email string + timeout time.Duration + err string + }{ + // valid (has mx record) + {email: "a@supabase.io"}, + {email: "support@supabase.io"}, + {email: "chris.stockton@supabase.io"}, + + // bad format + {email: "", err: "invalid_email_format"}, + {email: "io", err: "invalid_email_format"}, + {email: "supabase.io", err: "invalid_email_format"}, + {email: "@supabase.io", err: "invalid_email_format"}, + {email: "test@.supabase.io", err: "invalid_email_format"}, + + // invalid: valid mx records, but invalid and often typed + // (invalidEmailMap) + {email: "test@email.com", err: "invalid_email_address"}, + {email: "test@gmail.com", err: "invalid_email_address"}, + {email: "test@test.com", err: "invalid_email_dns"}, + + // very common typo + {email: "test@gamil.com", err: "invalid_email_dns"}, + + // invalid: valid mx records, but invalid and often typed + // (invalidHostMap) + {email: "a@example.com", err: "invalid_email_dns"}, + {email: "a@example.net", err: "invalid_email_dns"}, + {email: "a@example.org", err: "invalid_email_dns"}, + + // invalid: no mx records + {email: "a@test", err: "invalid_email_dns"}, + {email: "test@local", err: "invalid_email_dns"}, + {email: "test@test.local", err: "invalid_email_dns"}, + {email: "test@example", err: "invalid_email_dns"}, + {email: "test@invalid", err: "invalid_email_dns"}, + + // valid but not actually valid and typed a lot + {email: "a@invalid", err: "invalid_email_dns"}, + {email: "a@a.invalid", err: "invalid_email_dns"}, + {email: "test@invalid", err: "invalid_email_dns"}, + + // various invalid emails + {email: "test@test.localhost", err: "invalid_email_dns"}, + {email: "test@invalid.example.com", err: "invalid_email_dns"}, + {email: "test@no.such.email.host.supabase.io", err: "invalid_email_dns"}, + + // test blocked mx records + {email: "test@hotmail.com", err: "invalid_email_mx"}, + + // this low timeout should simulate a dns timeout, which should + // not be treated as an invalid email. + {email: "validemail@probablyaaaaaaaanotarealdomain.com", + timeout: time.Millisecond}, + + // likewise for a valid email + {email: "support@supabase.io", timeout: time.Millisecond}, + } + + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: "", + EmailValidationServiceHeaders: "", + EmailValidationBlockedMX: `["hotmail-com.olc.protection.outlook.com"]`, + } + + // Ensure the BlockedMX transformation occurs by calling Validate + if err := cfg.Validate(); err != nil { + t.Fatalf("failed to validate MailerConfiguration: %v", err) + } + + ev := newEmailValidator(cfg) + + for idx, tc := range cases { + func(timeout time.Duration) { + if timeout == 0 { + timeout = validateEmailTimeout + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + now := time.Now() + err := ev.Validate(ctx, tc.email) + dur := time.Since(now) + if max := timeout + (time.Millisecond * 50); max < dur { + t.Fatal("timeout was not respected") + } + + t.Logf("tc #%v - email %q", idx, tc.email) + if tc.err != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.err) + return + } + require.NoError(t, err) + + }(tc.timeout) + } +} diff --git a/metering/record.go b/internal/metering/record.go similarity index 73% rename from metering/record.go rename to internal/metering/record.go index 2ea469372..d9f9c5c2e 100644 --- a/metering/record.go +++ b/internal/metering/record.go @@ -7,11 +7,11 @@ import ( var logger = logrus.StandardLogger().WithField("metering", true) -func RecordLogin(loginType string, userID, instanceID uuid.UUID) { +func RecordLogin(loginType string, userID uuid.UUID) { logger.WithFields(logrus.Fields{ "action": "login", "login_method": loginType, - "instance_id": instanceID.String(), + "instance_id": uuid.Nil.String(), "user_id": userID.String(), }).Info("Login") } diff --git a/internal/models/amr.go b/internal/models/amr.go new file mode 100644 index 000000000..fdfd88361 --- /dev/null +++ b/internal/models/amr.go @@ -0,0 +1,43 @@ +package models + +import ( + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/storage" +) + +type AMRClaim struct { + ID uuid.UUID `json:"id" db:"id"` + SessionID uuid.UUID `json:"session_id" db:"session_id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + AuthenticationMethod *string `json:"authentication_method" db:"authentication_method"` +} + +func (AMRClaim) TableName() string { + tableName := "mfa_amr_claims" + return tableName +} + +func (cl *AMRClaim) IsAAL2Claim() bool { + return *cl.AuthenticationMethod == TOTPSignIn.String() || *cl.AuthenticationMethod == MFAPhone.String() || *cl.AuthenticationMethod == MFAWebAuthn.String() +} + +func AddClaimToSession(tx *storage.Connection, sessionId uuid.UUID, authenticationMethod AuthenticationMethod) error { + id := uuid.Must(uuid.NewV4()) + + currentTime := time.Now() + return tx.RawQuery("INSERT INTO "+(&pop.Model{Value: AMRClaim{}}).TableName()+ + `(id, session_id, created_at, updated_at, authentication_method) values (?, ?, ?, ?, ?) + ON CONFLICT ON CONSTRAINT mfa_amr_claims_session_id_authentication_method_pkey + DO UPDATE SET updated_at = ?;`, id, sessionId, currentTime, currentTime, authenticationMethod.String(), currentTime).Exec() +} + +func (a *AMRClaim) GetAuthenticationMethod() string { + if a.AuthenticationMethod == nil { + return "" + } + return *(a.AuthenticationMethod) +} diff --git a/internal/models/audit_log_entry.go b/internal/models/audit_log_entry.go new file mode 100644 index 000000000..5bbc9b0b6 --- /dev/null +++ b/internal/models/audit_log_entry.go @@ -0,0 +1,166 @@ +package models + +import ( + "bytes" + "fmt" + "net/http" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +type AuditAction string +type auditLogType string + +const ( + LoginAction AuditAction = "login" + LogoutAction AuditAction = "logout" + InviteAcceptedAction AuditAction = "invite_accepted" + UserSignedUpAction AuditAction = "user_signedup" + UserInvitedAction AuditAction = "user_invited" + UserDeletedAction AuditAction = "user_deleted" + UserModifiedAction AuditAction = "user_modified" + UserRecoveryRequestedAction AuditAction = "user_recovery_requested" + UserReauthenticateAction AuditAction = "user_reauthenticate_requested" + UserConfirmationRequestedAction AuditAction = "user_confirmation_requested" + UserRepeatedSignUpAction AuditAction = "user_repeated_signup" + UserUpdatePasswordAction AuditAction = "user_updated_password" + TokenRevokedAction AuditAction = "token_revoked" + TokenRefreshedAction AuditAction = "token_refreshed" + GenerateRecoveryCodesAction AuditAction = "generate_recovery_codes" + EnrollFactorAction AuditAction = "factor_in_progress" + UnenrollFactorAction AuditAction = "factor_unenrolled" + CreateChallengeAction AuditAction = "challenge_created" + VerifyFactorAction AuditAction = "verification_attempted" + DeleteFactorAction AuditAction = "factor_deleted" + DeleteRecoveryCodesAction AuditAction = "recovery_codes_deleted" + UpdateFactorAction AuditAction = "factor_updated" + MFACodeLoginAction AuditAction = "mfa_code_login" + IdentityUnlinkAction AuditAction = "identity_unlinked" + + account auditLogType = "account" + team auditLogType = "team" + token auditLogType = "token" + user auditLogType = "user" + factor auditLogType = "factor" + recoveryCodes auditLogType = "recovery_codes" +) + +var ActionLogTypeMap = map[AuditAction]auditLogType{ + LoginAction: account, + LogoutAction: account, + InviteAcceptedAction: account, + UserSignedUpAction: team, + UserInvitedAction: team, + UserDeletedAction: team, + TokenRevokedAction: token, + TokenRefreshedAction: token, + UserModifiedAction: user, + UserRecoveryRequestedAction: user, + UserConfirmationRequestedAction: user, + UserRepeatedSignUpAction: user, + UserUpdatePasswordAction: user, + GenerateRecoveryCodesAction: user, + EnrollFactorAction: factor, + UnenrollFactorAction: factor, + CreateChallengeAction: factor, + VerifyFactorAction: factor, + DeleteFactorAction: factor, + UpdateFactorAction: factor, + MFACodeLoginAction: factor, + DeleteRecoveryCodesAction: recoveryCodes, +} + +// AuditLogEntry is the database model for audit log entries. +type AuditLogEntry struct { + ID uuid.UUID `json:"id" db:"id"` + Payload JSONMap `json:"payload" db:"payload"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + IPAddress string `json:"ip_address" db:"ip_address"` + + DONTUSEINSTANCEID uuid.UUID `json:"-" db:"instance_id"` +} + +func (AuditLogEntry) TableName() string { + tableName := "audit_log_entries" + return tableName +} + +func NewAuditLogEntry(r *http.Request, tx *storage.Connection, actor *User, action AuditAction, ipAddress string, traits map[string]interface{}) error { + id := uuid.Must(uuid.NewV4()) + + username := actor.GetEmail() + + if actor.GetPhone() != "" { + username = actor.GetPhone() + } + + payload := map[string]interface{}{ + "actor_id": actor.ID, + "actor_via_sso": actor.IsSSOUser, + "actor_username": username, + "action": action, + "log_type": ActionLogTypeMap[action], + } + l := AuditLogEntry{ + ID: id, + Payload: JSONMap(payload), + IPAddress: ipAddress, + } + + observability.LogEntrySetFields(r, logrus.Fields{ + "auth_event": logrus.Fields(payload), + }) + + if name, ok := actor.UserMetaData["full_name"]; ok { + l.Payload["actor_name"] = name + } + + if traits != nil { + l.Payload["traits"] = traits + } + + if err := tx.Create(&l); err != nil { + return errors.Wrap(err, "Database error creating audit log entry") + } + + return nil +} + +func FindAuditLogEntries(tx *storage.Connection, filterColumns []string, filterValue string, pageParams *Pagination) ([]*AuditLogEntry, error) { + q := tx.Q().Order("created_at desc").Where("instance_id = ?", uuid.Nil) + + if len(filterColumns) > 0 && filterValue != "" { + lf := "%" + filterValue + "%" + + builder := bytes.NewBufferString("(") + values := make([]interface{}, len(filterColumns)) + + for idx, col := range filterColumns { + builder.WriteString(fmt.Sprintf("payload->>'%s' ILIKE ?", col)) + values[idx] = lf + + if idx+1 < len(filterColumns) { + builder.WriteString(" OR ") + } + } + builder.WriteString(")") + + q = q.Where(builder.String(), values...) + } + + logs := []*AuditLogEntry{} + var err error + if pageParams != nil { + err = q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&logs) // #nosec G115 + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + } else { + err = q.All(&logs) + } + + return logs, err +} diff --git a/internal/models/challenge.go b/internal/models/challenge.go new file mode 100644 index 000000000..3de5b4d02 --- /dev/null +++ b/internal/models/challenge.go @@ -0,0 +1,124 @@ +package models + +import ( + "database/sql/driver" + "fmt" + + "encoding/json" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "time" +) + +type Challenge struct { + ID uuid.UUID `json:"challenge_id" db:"id"` + FactorID uuid.UUID `json:"factor_id" db:"factor_id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + VerifiedAt *time.Time `json:"verified_at,omitempty" db:"verified_at"` + IPAddress string `json:"ip_address" db:"ip_address"` + Factor *Factor `json:"factor,omitempty" belongs_to:"factor"` + OtpCode string `json:"otp_code,omitempty" db:"otp_code"` + WebAuthnSessionData *WebAuthnSessionData `json:"web_authn_session_data,omitempty" db:"web_authn_session_data"` +} + +type WebAuthnSessionData struct { + *webauthn.SessionData +} + +func (s *WebAuthnSessionData) Scan(value interface{}) error { + if value == nil { + s.SessionData = nil + return nil + } + + // Handle byte and string as a precaution, in postgres driver, json/jsonb should be returned as []byte + var data []byte + switch v := value.(type) { + case []byte: + data = v + case string: + data = []byte(v) + default: + panic(fmt.Sprintf("unsupported type for web_authn_session_data: %T", value)) + } + + if len(data) == 0 { + s.SessionData = nil + return nil + } + if s.SessionData == nil { + s.SessionData = &webauthn.SessionData{} + } + return json.Unmarshal(data, s.SessionData) + +} + +func (s *WebAuthnSessionData) Value() (driver.Value, error) { + if s == nil || s.SessionData == nil { + return nil, nil + } + return json.Marshal(s.SessionData) +} + +func (ws *WebAuthnSessionData) ToChallenge(factorID uuid.UUID, ipAddress string) *Challenge { + id := uuid.Must(uuid.NewV4()) + return &Challenge{ + ID: id, + FactorID: factorID, + IPAddress: ipAddress, + WebAuthnSessionData: &WebAuthnSessionData{ + ws.SessionData, + }, + } + +} + +func (Challenge) TableName() string { + tableName := "mfa_challenges" + return tableName +} + +// Update the verification timestamp +func (c *Challenge) Verify(tx *storage.Connection) error { + now := time.Now() + c.VerifiedAt = &now + return tx.UpdateOnly(c, "verified_at") +} + +func (c *Challenge) HasExpired(expiryDuration float64) bool { + return time.Now().After(c.GetExpiryTime(expiryDuration)) +} + +func (c *Challenge) GetExpiryTime(expiryDuration float64) time.Time { + return c.CreatedAt.Add(time.Second * time.Duration(expiryDuration)) +} + +func (c *Challenge) SetOtpCode(otpCode string, encrypt bool, encryptionKeyID, encryptionKey string) error { + c.OtpCode = otpCode + if encrypt { + es, err := crypto.NewEncryptedString(c.ID.String(), []byte(otpCode), encryptionKeyID, encryptionKey) + if err != nil { + return err + } + + c.OtpCode = es.String() + } + return nil + +} + +func (c *Challenge) GetOtpCode(decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (string, bool, error) { + if es := crypto.ParseEncryptedString(c.OtpCode); es != nil { + bytes, err := es.Decrypt(c.ID.String(), decryptionKeys) + if err != nil { + return "", false, err + } + + return string(bytes), encrypt && es.ShouldReEncrypt(encryptionKeyID), nil + } + + return c.OtpCode, encrypt, nil + +} diff --git a/internal/models/cleanup.go b/internal/models/cleanup.go new file mode 100644 index 000000000..9669c8d4b --- /dev/null +++ b/internal/models/cleanup.go @@ -0,0 +1,136 @@ +package models + +import ( + "context" + "fmt" + "sync/atomic" + + "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" + + "go.opentelemetry.io/otel/attribute" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +type Cleaner interface { + Clean(*storage.Connection) (int, error) +} + +type Cleanup struct { + cleanupStatements []string + + // cleanupNext holds an atomically incrementing value that determines which of + // the cleanupStatements will be run next. + cleanupNext uint32 + + // cleanupAffectedRows tracks an OpenTelemetry metric on the total number of + // cleaned up rows. + cleanupAffectedRows atomic.Int64 +} + +func NewCleanup(config *conf.GlobalConfiguration) *Cleanup { + tableUsers := User{}.TableName() + tableRefreshTokens := RefreshToken{}.TableName() + tableSessions := Session{}.TableName() + tableRelayStates := SAMLRelayState{}.TableName() + tableFlowStates := FlowState{}.TableName() + tableMFAChallenges := Challenge{}.TableName() + tableMFAFactors := Factor{}.TableName() + + c := &Cleanup{} + + // These statements intentionally use SELECT ... FOR UPDATE SKIP LOCKED + // as this makes sure that only rows that are not being used in another + // transaction are deleted. These deletes are thus very quick and + // efficient, as they don't wait on other transactions. + c.cleanupStatements = append(c.cleanupStatements, + fmt.Sprintf("delete from %q where id in (select id from %q where revoked is true and updated_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens), + fmt.Sprintf("update %q set revoked = true, updated_at = now() where id in (select %q.id from %q join %q on %q.session_id = %q.id where %q.not_after < now() - interval '24 hours' and %q.revoked is false limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens, tableRefreshTokens, tableSessions, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens), + // sessions are deleted after 72 hours to allow refresh tokens + // to be deleted piecemeal; 10 at once so that cascades don't + // overwork the database + fmt.Sprintf("delete from %q where id in (select id from %q where not_after < now() - interval '72 hours' limit 10 for update skip locked);", tableSessions, tableSessions), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRelayStates, tableRelayStates), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableFlowStates, tableFlowStates), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableMFAChallenges, tableMFAChallenges), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' and status = 'unverified' limit 100 for update skip locked);", tableMFAFactors, tableMFAFactors), + ) + + if config.External.AnonymousUsers.Enabled { + // delete anonymous users older than 30 days + c.cleanupStatements = append(c.cleanupStatements, + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '30 days' and is_anonymous is true limit 100 for update skip locked);", tableUsers, tableUsers), + ) + } + + if config.Sessions.Timebox != nil { + timeboxSeconds := int((*config.Sessions.Timebox).Seconds()) + + c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select id from %q where created_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked);", tableSessions, tableSessions, timeboxSeconds)) + } + + if config.Sessions.InactivityTimeout != nil { + inactivitySeconds := int((*config.Sessions.InactivityTimeout).Seconds()) + + // delete sessions with a refreshed_at column + c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select id from %q where refreshed_at is not null and refreshed_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked);", tableSessions, tableSessions, inactivitySeconds)) + + // delete sessions without a refreshed_at column by looking for + // unrevoked refresh_tokens + c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select %q.id as id from %q, %q where %q.session_id = %q.id and %q.refreshed_at is null and %q.revoked is false and %q.updated_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked)", tableSessions, tableSessions, tableSessions, tableRefreshTokens, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens, tableRefreshTokens, inactivitySeconds)) + } + + meter := otel.Meter("gotrue") + + _, err := meter.Int64ObservableCounter( + "gotrue_cleanup_affected_rows", + metric.WithDescription("Number of affected rows from cleaning up stale entities"), + metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error { + o.Observe(c.cleanupAffectedRows.Load()) + return nil + }), + ) + + if err != nil { + logrus.WithError(err).Error("unable to get gotrue.gotrue_cleanup_rows counter metric") + } + + return c +} + +// Cleanup removes stale entities in the database. You can call it on each +// request or as a periodic background job. It does quick lockless updates or +// deletes, has an execution timeout and acquire timeout so that cleanups do +// not affect performance of other database jobs. Note that calling this does +// not clean up the whole database, but does a small piecemeal clean up each +// time when called. +func (c *Cleanup) Clean(db *storage.Connection) (int, error) { + ctx, span := observability.Tracer("gotrue").Start(db.Context(), "database-cleanup") + defer span.End() + + affectedRows := 0 + defer span.SetAttributes(attribute.Int64("gotrue.cleanup.affected_rows", int64(affectedRows))) + + if err := db.WithContext(ctx).Transaction(func(tx *storage.Connection) error { + nextIndex := atomic.AddUint32(&c.cleanupNext, 1) % uint32(len(c.cleanupStatements)) // #nosec G115 + statement := c.cleanupStatements[nextIndex] + + count, terr := tx.RawQuery(statement).ExecWithCount() + if terr != nil { + return terr + } + + affectedRows += count + + return nil + }); err != nil { + return affectedRows, err + } + c.cleanupAffectedRows.Add(int64(affectedRows)) + + return affectedRows, nil +} diff --git a/internal/models/cleanup_test.go b/internal/models/cleanup_test.go new file mode 100644 index 000000000..618fbbaa5 --- /dev/null +++ b/internal/models/cleanup_test.go @@ -0,0 +1,31 @@ +package models + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage/test" +) + +func TestCleanup(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + timebox := 10 * time.Second + inactivityTimeout := 5 * time.Second + globalConfig.Sessions.Timebox = &timebox + globalConfig.Sessions.InactivityTimeout = &inactivityTimeout + globalConfig.External.AnonymousUsers.Enabled = true + + cleanup := NewCleanup(globalConfig) + + for i := 0; i < 100; i += 1 { + _, err := cleanup.Clean(conn) + require.NoError(t, err) + } +} diff --git a/internal/models/connection.go b/internal/models/connection.go new file mode 100644 index 000000000..80acccc57 --- /dev/null +++ b/internal/models/connection.go @@ -0,0 +1,62 @@ +package models + +import ( + "github.com/gobuffalo/pop/v6" + "github.com/supabase/auth/internal/storage" +) + +type Pagination struct { + Page uint64 + PerPage uint64 + Count uint64 +} + +func (p *Pagination) Offset() uint64 { + return (p.Page - 1) * p.PerPage +} + +type SortDirection string + +const Ascending SortDirection = "ASC" +const Descending SortDirection = "DESC" +const CreatedAt = "created_at" + +type SortParams struct { + Fields []SortField +} + +type SortField struct { + Name string + Dir SortDirection +} + +// TruncateAll deletes all data from the database, as managed by GoTrue. Not +// intended for use outside of tests. +func TruncateAll(conn *storage.Connection) error { + return conn.Transaction(func(tx *storage.Connection) error { + tables := []string{ + (&pop.Model{Value: User{}}).TableName(), + (&pop.Model{Value: Identity{}}).TableName(), + (&pop.Model{Value: RefreshToken{}}).TableName(), + (&pop.Model{Value: AuditLogEntry{}}).TableName(), + (&pop.Model{Value: Session{}}).TableName(), + (&pop.Model{Value: Factor{}}).TableName(), + (&pop.Model{Value: Challenge{}}).TableName(), + (&pop.Model{Value: AMRClaim{}}).TableName(), + (&pop.Model{Value: SSOProvider{}}).TableName(), + (&pop.Model{Value: SSODomain{}}).TableName(), + (&pop.Model{Value: SAMLProvider{}}).TableName(), + (&pop.Model{Value: SAMLRelayState{}}).TableName(), + (&pop.Model{Value: FlowState{}}).TableName(), + (&pop.Model{Value: OneTimeToken{}}).TableName(), + } + + for _, tableName := range tables { + if err := tx.RawQuery("DELETE FROM " + tableName + " CASCADE").Exec(); err != nil { + return err + } + } + + return nil + }) +} diff --git a/internal/models/db_test.go b/internal/models/db_test.go new file mode 100644 index 000000000..c3d6ab250 --- /dev/null +++ b/internal/models/db_test.go @@ -0,0 +1,24 @@ +package models + +import ( + "testing" + + "github.com/gobuffalo/pop/v6" + "github.com/stretchr/testify/assert" +) + +func TestTableNameNamespacing(t *testing.T) { + cases := []struct { + expected string + value interface{} + }{ + {expected: "audit_log_entries", value: []*AuditLogEntry{}}, + {expected: "refresh_tokens", value: []*RefreshToken{}}, + {expected: "users", value: []*User{}}, + } + + for _, tc := range cases { + m := &pop.Model{Value: tc.value} + assert.Equal(t, tc.expected, m.TableName()) + } +} diff --git a/internal/models/errors.go b/internal/models/errors.go new file mode 100644 index 000000000..96f831969 --- /dev/null +++ b/internal/models/errors.go @@ -0,0 +1,125 @@ +package models + +// IsNotFoundError returns whether an error represents a "not found" error. +func IsNotFoundError(err error) bool { + switch err.(type) { + case UserNotFoundError, *UserNotFoundError: + return true + case SessionNotFoundError, *SessionNotFoundError: + return true + case ConfirmationTokenNotFoundError, *ConfirmationTokenNotFoundError: + return true + case ConfirmationOrRecoveryTokenNotFoundError, *ConfirmationOrRecoveryTokenNotFoundError: + return true + case RefreshTokenNotFoundError, *RefreshTokenNotFoundError: + return true + case IdentityNotFoundError, *IdentityNotFoundError: + return true + case ChallengeNotFoundError, *ChallengeNotFoundError: + return true + case FactorNotFoundError, *FactorNotFoundError: + return true + case SSOProviderNotFoundError, *SSOProviderNotFoundError: + return true + case SAMLRelayStateNotFoundError, *SAMLRelayStateNotFoundError: + return true + case FlowStateNotFoundError, *FlowStateNotFoundError: + return true + case OneTimeTokenNotFoundError, *OneTimeTokenNotFoundError: + return true + } + return false +} + +type SessionNotFoundError struct{} + +func (e SessionNotFoundError) Error() string { + return "Session not found" +} + +// UserNotFoundError represents when a user is not found. +type UserNotFoundError struct{} + +func (e UserNotFoundError) Error() string { + return "User not found" +} + +// IdentityNotFoundError represents when an identity is not found. +type IdentityNotFoundError struct{} + +func (e IdentityNotFoundError) Error() string { + return "Identity not found" +} + +// ConfirmationOrRecoveryTokenNotFoundError represents when a confirmation or recovery token is not found. +type ConfirmationOrRecoveryTokenNotFoundError struct{} + +func (e ConfirmationOrRecoveryTokenNotFoundError) Error() string { + return "Confirmation or Recovery Token not found" +} + +// ConfirmationTokenNotFoundError represents when a confirmation token is not found. +type ConfirmationTokenNotFoundError struct{} + +func (e ConfirmationTokenNotFoundError) Error() string { + return "Confirmation Token not found" +} + +// RefreshTokenNotFoundError represents when a refresh token is not found. +type RefreshTokenNotFoundError struct{} + +func (e RefreshTokenNotFoundError) Error() string { + return "Refresh Token not found" +} + +// FactorNotFoundError represents when a user is not found. +type FactorNotFoundError struct{} + +func (e FactorNotFoundError) Error() string { + return "Factor not found" +} + +// ChallengeNotFoundError represents when a user is not found. +type ChallengeNotFoundError struct{} + +func (e ChallengeNotFoundError) Error() string { + return "Challenge not found" +} + +// SSOProviderNotFoundError represents an error when a SSO Provider can't be +// found. +type SSOProviderNotFoundError struct{} + +func (e SSOProviderNotFoundError) Error() string { + return "SSO Identity Provider not found" +} + +// SAMLRelayStateNotFoundError represents an error when a SAML relay state +// can't be found. +type SAMLRelayStateNotFoundError struct{} + +func (e SAMLRelayStateNotFoundError) Error() string { + return "SAML RelayState not found" +} + +// FlowStateNotFoundError represents an error when an FlowState can't be +// found. +type FlowStateNotFoundError struct{} + +func (e FlowStateNotFoundError) Error() string { + return "Flow State not found" +} + +func IsUniqueConstraintViolatedError(err error) bool { + switch err.(type) { + case UserEmailUniqueConflictError, *UserEmailUniqueConflictError: + return true + } + return false +} + +type UserEmailUniqueConflictError struct{} + +func (e UserEmailUniqueConflictError) Error() string { + return "User email unique constraint violated" +} diff --git a/internal/models/factor.go b/internal/models/factor.go new file mode 100644 index 000000000..72309a855 --- /dev/null +++ b/internal/models/factor.go @@ -0,0 +1,404 @@ +package models + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" +) + +type FactorState int + +const ( + FactorStateUnverified FactorState = iota + FactorStateVerified +) + +func (factorState FactorState) String() string { + switch factorState { + case FactorStateUnverified: + return "unverified" + case FactorStateVerified: + return "verified" + } + return "" +} + +const TOTP = "totp" +const Phone = "phone" +const WebAuthn = "webauthn" + +type AuthenticationMethod int + +const ( + OAuth AuthenticationMethod = iota + PasswordGrant + OTP + TOTPSignIn + MFAPhone + MFAWebAuthn + SSOSAML + Recovery + Invite + MagicLink + EmailSignup + EmailChange + TokenRefresh + Anonymous + Web3 +) + +func (authMethod AuthenticationMethod) String() string { + switch authMethod { + case OAuth: + return "oauth" + case PasswordGrant: + return "password" + case OTP: + return "otp" + case TOTPSignIn: + return "totp" + case Recovery: + return "recovery" + case Invite: + return "invite" + case SSOSAML: + return "sso/saml" + case MagicLink: + return "magiclink" + case EmailSignup: + return "email/signup" + case EmailChange: + return "email_change" + case TokenRefresh: + return "token_refresh" + case Anonymous: + return "anonymous" + case MFAPhone: + return "mfa/phone" + case MFAWebAuthn: + return "mfa/webauthn" + case Web3: + return "web3" + } + return "" +} + +func ParseAuthenticationMethod(authMethod string) (AuthenticationMethod, error) { + if strings.HasSuffix(authMethod, "signup") { + authMethod = "email/signup" + } + switch authMethod { + case "oauth": + return OAuth, nil + case "password": + return PasswordGrant, nil + case "otp": + return OTP, nil + case "totp": + return TOTPSignIn, nil + case "recovery": + return Recovery, nil + case "invite": + return Invite, nil + case "sso/saml": + return SSOSAML, nil + case "magiclink": + return MagicLink, nil + case "email/signup": + return EmailSignup, nil + case "email_change": + return EmailChange, nil + case "token_refresh": + return TokenRefresh, nil + case "mfa/sms": + return MFAPhone, nil + case "mfa/webauthn": + return MFAWebAuthn, nil + case "web3": + return Web3, nil + + } + return 0, fmt.Errorf("unsupported authentication method %q", authMethod) +} + +type Factor struct { + ID uuid.UUID `json:"id" db:"id"` + // TODO: Consider removing this nested user field. We don't use it. + User User `json:"-" belongs_to:"user"` + UserID uuid.UUID `json:"-" db:"user_id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + Status string `json:"status" db:"status"` + FriendlyName string `json:"friendly_name,omitempty" db:"friendly_name"` + Secret string `json:"-" db:"secret"` + FactorType string `json:"factor_type" db:"factor_type"` + Challenge []Challenge `json:"-" has_many:"challenges"` + Phone storage.NullString `json:"phone" db:"phone"` + LastChallengedAt *time.Time `json:"last_challenged_at" db:"last_challenged_at"` + WebAuthnCredential *WebAuthnCredential `json:"-" db:"web_authn_credential"` + WebAuthnAAGUID *uuid.UUID `json:"web_authn_aaguid,omitempty" db:"web_authn_aaguid"` +} + +type WebAuthnCredential struct { + webauthn.Credential +} + +func (wc *WebAuthnCredential) Value() (driver.Value, error) { + if wc == nil { + return nil, nil + } + return json.Marshal(wc) +} + +func (wc *WebAuthnCredential) Scan(value interface{}) error { + if value == nil { + wc.Credential = webauthn.Credential{} + return nil + } + // Handle byte and string as a precaution, in postgres driver, json/jsonb should be returned as []byte + var data []byte + switch v := value.(type) { + case []byte: + data = v + case string: + data = []byte(v) + default: + return fmt.Errorf("unsupported type for web_authn_credential: %T", value) + } + if len(data) == 0 { + wc.Credential = webauthn.Credential{} + return nil + } + return json.Unmarshal(data, &wc.Credential) +} + +func (Factor) TableName() string { + tableName := "mfa_factors" + return tableName +} + +func NewFactor(user *User, friendlyName string, factorType string, state FactorState) *Factor { + id := uuid.Must(uuid.NewV4()) + + factor := &Factor{ + ID: id, + UserID: user.ID, + Status: state.String(), + FriendlyName: friendlyName, + FactorType: factorType, + } + return factor +} + +func NewTOTPFactor(user *User, friendlyName string) *Factor { + return NewFactor(user, friendlyName, TOTP, FactorStateUnverified) +} + +func NewPhoneFactor(user *User, phone, friendlyName string) *Factor { + factor := NewFactor(user, friendlyName, Phone, FactorStateUnverified) + factor.Phone = storage.NullString(phone) + return factor +} + +func NewWebAuthnFactor(user *User, friendlyName string) *Factor { + factor := NewFactor(user, friendlyName, WebAuthn, FactorStateUnverified) + return factor +} + +func (f *Factor) SetSecret(secret string, encrypt bool, encryptionKeyID, encryptionKey string) error { + f.Secret = secret + if encrypt { + es, err := crypto.NewEncryptedString(f.ID.String(), []byte(secret), encryptionKeyID, encryptionKey) + if err != nil { + return err + } + + f.Secret = es.String() + } + + return nil +} + +func (f *Factor) GetSecret(decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (string, bool, error) { + if es := crypto.ParseEncryptedString(f.Secret); es != nil { + bytes, err := es.Decrypt(f.ID.String(), decryptionKeys) + if err != nil { + return "", false, err + } + + return string(bytes), encrypt && es.ShouldReEncrypt(encryptionKeyID), nil + } + + return f.Secret, encrypt, nil +} + +func (f *Factor) SaveWebAuthnCredential(tx *storage.Connection, credential *webauthn.Credential) error { + f.WebAuthnCredential = &WebAuthnCredential{ + Credential: *credential, + } + + if len(credential.Authenticator.AAGUID) > 0 { + aaguidUUID, err := uuid.FromBytes(credential.Authenticator.AAGUID) + if err != nil { + return fmt.Errorf("WebAuthn authenticator AAGUID is not UUID: %w", err) + } + f.WebAuthnAAGUID = &aaguidUUID + } else { + f.WebAuthnAAGUID = nil + } + + return tx.UpdateOnly(f, "web_authn_credential", "web_authn_aaguid", "updated_at") +} + +func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor, error) { + var factor Factor + err := conn.Find(&factor, factorID) + if err != nil && errors.Cause(err) == sql.ErrNoRows { + return nil, FactorNotFoundError{} + } else if err != nil { + return nil, err + } + return &factor, nil +} + +func DeleteUnverifiedFactors(tx *storage.Connection, user *User, factorType string) error { + if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Factor{}}).TableName()+" WHERE user_id = ? and status = ? and factor_type = ?", user.ID, FactorStateUnverified.String(), factorType).Exec(); err != nil { + return err + } + + return nil +} + +func (f *Factor) CreateChallenge(ipAddress string) *Challenge { + id := uuid.Must(uuid.NewV4()) + challenge := &Challenge{ + ID: id, + FactorID: f.ID, + IPAddress: ipAddress, + } + + return challenge +} +func (f *Factor) WriteChallengeToDatabase(tx *storage.Connection, challenge *Challenge) error { + if challenge.FactorID != f.ID { + return errors.New("Can only write challenges that you own") + } + now := time.Now() + f.LastChallengedAt = &now + if terr := tx.Create(challenge); terr != nil { + return terr + } + if err := tx.UpdateOnly(f, "last_challenged_at"); err != nil { + return err + } + return nil +} + +func (f *Factor) CreatePhoneChallenge(ipAddress string, otpCode string, encrypt bool, encryptionKeyID, encryptionKey string) (*Challenge, error) { + phoneChallenge := f.CreateChallenge(ipAddress) + if err := phoneChallenge.SetOtpCode(otpCode, encrypt, encryptionKeyID, encryptionKey); err != nil { + return nil, err + } + return phoneChallenge, nil +} + +// UpdateFriendlyName changes the friendly name +func (f *Factor) UpdateFriendlyName(tx *storage.Connection, friendlyName string) error { + f.FriendlyName = friendlyName + return tx.UpdateOnly(f, "friendly_name", "updated_at") +} + +func (f *Factor) UpdatePhone(tx *storage.Connection, phone string) error { + f.Phone = storage.NullString(phone) + return tx.UpdateOnly(f, "phone", "updated_at") +} + +// UpdateStatus modifies the factor status +func (f *Factor) UpdateStatus(tx *storage.Connection, state FactorState) error { + f.Status = state.String() + return tx.UpdateOnly(f, "status", "updated_at") +} + +func (f *Factor) DowngradeSessionsToAAL1(tx *storage.Connection) error { + sessions, err := FindSessionsByFactorID(tx, f.ID) + if err != nil { + return err + } + for _, session := range sessions { + if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: AMRClaim{}}).TableName()+" WHERE session_id = ? AND authentication_method = ?", session.ID, f.FactorType).Exec(); err != nil { + return err + } + } + return updateFactorAssociatedSessions(tx, f.UserID, f.ID, AAL1.String()) +} + +func (f *Factor) IsVerified() bool { + return f.Status == FactorStateVerified.String() +} + +func (f *Factor) IsUnverified() bool { + return f.Status == FactorStateUnverified.String() +} + +func (f *Factor) IsPhoneFactor() bool { + return f.FactorType == Phone +} + +func (f *Factor) FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) { + var challenge Challenge + err := conn.Q().Where("id = ? and factor_id = ?", challengeID, f.ID).First(&challenge) + if err != nil && errors.Cause(err) == sql.ErrNoRows { + return nil, ChallengeNotFoundError{} + } else if err != nil { + return nil, err + } + return &challenge, nil +} + +func DeleteFactorsByUserId(tx *storage.Connection, userId uuid.UUID) error { + if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Factor{}}).TableName()+" WHERE user_id = ?", userId).Exec(); err != nil { + return err + } + return nil +} + +func DeleteExpiredFactors(tx *storage.Connection, validityDuration time.Duration) error { + totalSeconds := int64(validityDuration / time.Second) + validityInterval := fmt.Sprintf("interval '%d seconds'", totalSeconds) + + factorTable := (&pop.Model{Value: Factor{}}).TableName() + challengeTable := (&pop.Model{Value: Challenge{}}).TableName() + + query := fmt.Sprintf(`delete from %q where status != 'verified' and not exists (select * from %q where %q.id = %q.factor_id ) and created_at + %s < current_timestamp;`, factorTable, challengeTable, factorTable, challengeTable, validityInterval) + if err := tx.RawQuery(query).Exec(); err != nil { + return err + } + return nil +} + +func (f *Factor) FindLatestUnexpiredChallenge(tx *storage.Connection, expiryDuration float64) (*Challenge, error) { + now := time.Now() + var challenge Challenge + expirationTime := now.Add(time.Duration(expiryDuration) * time.Second) + + err := tx.Where("sent_at > ? and factor_id = ?", expirationTime, f.ID). + Order("sent_at desc"). + First(&challenge) + + if err != nil && errors.Cause(err) == sql.ErrNoRows { + return nil, ChallengeNotFoundError{} + } else if err != nil { + return nil, err + } + return &challenge, nil +} diff --git a/internal/models/factor_test.go b/internal/models/factor_test.go new file mode 100644 index 000000000..614cff239 --- /dev/null +++ b/internal/models/factor_test.go @@ -0,0 +1,74 @@ +package models + +import ( + "encoding/json" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type FactorTestSuite struct { + suite.Suite + db *storage.Connection + TestFactor *Factor +} + +func TestFactor(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + ts := &FactorTestSuite{ + db: conn, + } + defer ts.db.Close() + suite.Run(t, ts) +} + +func (ts *FactorTestSuite) SetupTest() { + TruncateAll(ts.db) + user, err := NewUser("", "agenericemail@gmail.com", "secret", "test", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) + + factor := NewTOTPFactor(user, "asimplename") + require.NoError(ts.T(), factor.SetSecret("topsecret", false, "", "")) + require.NoError(ts.T(), ts.db.Create(factor)) + ts.TestFactor = factor +} + +func (ts *FactorTestSuite) TestFindFactorByFactorID() { + n, err := FindFactorByFactorID(ts.db, ts.TestFactor.ID) + require.NoError(ts.T(), err) + require.Equal(ts.T(), ts.TestFactor.ID, n.ID) + + _, err = FindFactorByFactorID(ts.db, uuid.Nil) + require.EqualError(ts.T(), err, FactorNotFoundError{}.Error()) +} + +func (ts *FactorTestSuite) TestUpdateStatus() { + newFactorStatus := FactorStateVerified + require.NoError(ts.T(), ts.TestFactor.UpdateStatus(ts.db, newFactorStatus)) + require.Equal(ts.T(), newFactorStatus.String(), ts.TestFactor.Status) +} + +func (ts *FactorTestSuite) TestUpdateFriendlyName() { + newName := "newfactorname" + require.NoError(ts.T(), ts.TestFactor.UpdateFriendlyName(ts.db, newName)) + require.Equal(ts.T(), newName, ts.TestFactor.FriendlyName) +} + +func (ts *FactorTestSuite) TestEncodedFactorDoesNotLeakSecret() { + encodedFactor, err := json.Marshal(ts.TestFactor) + require.NoError(ts.T(), err) + + decodedFactor := Factor{} + json.Unmarshal(encodedFactor, &decodedFactor) + require.Equal(ts.T(), decodedFactor.Secret, "") +} diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go new file mode 100644 index 000000000..9a770d81d --- /dev/null +++ b/internal/models/flow_state.go @@ -0,0 +1,169 @@ +package models + +import ( + "crypto/sha256" + "crypto/subtle" + "database/sql" + "encoding/base64" + "fmt" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" + + "github.com/gofrs/uuid" +) + +const InvalidCodeChallengeError = "code challenge does not match previously saved code verifier" +const InvalidCodeMethodError = "code challenge method not supported" + +type FlowState struct { + ID uuid.UUID `json:"id" db:"id"` + UserID *uuid.UUID `json:"user_id,omitempty" db:"user_id"` + AuthCode string `json:"auth_code" db:"auth_code"` + AuthenticationMethod string `json:"authentication_method" db:"authentication_method"` + CodeChallenge string `json:"code_challenge" db:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method" db:"code_challenge_method"` + ProviderType string `json:"provider_type" db:"provider_type"` + ProviderAccessToken string `json:"provider_access_token" db:"provider_access_token"` + ProviderRefreshToken string `json:"provider_refresh_token" db:"provider_refresh_token"` + AuthCodeIssuedAt *time.Time `json:"auth_code_issued_at" db:"auth_code_issued_at"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type CodeChallengeMethod int + +const ( + SHA256 CodeChallengeMethod = iota + Plain +) + +func (codeChallengeMethod CodeChallengeMethod) String() string { + switch codeChallengeMethod { + case SHA256: + return "s256" + case Plain: + return "plain" + } + return "" +} + +func ParseCodeChallengeMethod(codeChallengeMethod string) (CodeChallengeMethod, error) { + switch strings.ToLower(codeChallengeMethod) { + case "s256": + return SHA256, nil + case "plain": + return Plain, nil + } + return 0, fmt.Errorf("unsupported code_challenge method %q", codeChallengeMethod) +} + +type FlowType int + +const ( + PKCEFlow FlowType = iota + ImplicitFlow +) + +func (flowType FlowType) String() string { + switch flowType { + case PKCEFlow: + return "pkce" + case ImplicitFlow: + return "implicit" + } + return "" +} + +func (FlowState) TableName() string { + tableName := "flow_state" + return tableName +} + +func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) *FlowState { + id := uuid.Must(uuid.NewV4()) + authCode := uuid.Must(uuid.NewV4()) + flowState := &FlowState{ + ID: id, + ProviderType: providerType, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod.String(), + AuthCode: authCode.String(), + AuthenticationMethod: authenticationMethod.String(), + UserID: userID, + } + return flowState +} + +func FindFlowStateByAuthCode(tx *storage.Connection, authCode string) (*FlowState, error) { + obj := &FlowState{} + if err := tx.Eager().Q().Where("auth_code = ?", authCode).First(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, FlowStateNotFoundError{} + } + return nil, errors.Wrap(err, "error finding flow state") + } + + return obj, nil +} + +func FindFlowStateByID(tx *storage.Connection, id string) (*FlowState, error) { + obj := &FlowState{} + if err := tx.Eager().Q().Where("id = ?", id).First(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, FlowStateNotFoundError{} + } + return nil, errors.Wrap(err, "error finding flow state") + } + + return obj, nil +} + +func FindFlowStateByUserID(tx *storage.Connection, id string, authenticationMethod AuthenticationMethod) (*FlowState, error) { + obj := &FlowState{} + if err := tx.Eager().Q().Where("user_id = ? and authentication_method = ?", id, authenticationMethod).Last(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, FlowStateNotFoundError{} + } + return nil, errors.Wrap(err, "error finding flow state") + } + + return obj, nil +} + +func (f *FlowState) VerifyPKCE(codeVerifier string) error { + switch f.CodeChallengeMethod { + case SHA256.String(): + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + encodedCodeVerifier := base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(encodedCodeVerifier)) != 1 { + return errors.New(InvalidCodeChallengeError) + } + case Plain.String(): + if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(codeVerifier)) != 1 { + return errors.New(InvalidCodeChallengeError) + } + default: + return errors.New(InvalidCodeMethodError) + + } + return nil +} + +func (f *FlowState) IsExpired(expiryDuration time.Duration) bool { + if f.AuthCodeIssuedAt != nil && f.AuthenticationMethod == MagicLink.String() { + return time.Now().After(f.AuthCodeIssuedAt.Add(expiryDuration)) + } + return time.Now().After(f.CreatedAt.Add(expiryDuration)) +} + +func (f *FlowState) RecordAuthCodeIssuedAtTime(tx *storage.Connection) error { + issueTime := time.Now() + f.AuthCodeIssuedAt = &issueTime + if err := tx.Update(f); err != nil { + return err + } + return nil +} diff --git a/internal/models/identity.go b/internal/models/identity.go new file mode 100644 index 000000000..c647cbc22 --- /dev/null +++ b/internal/models/identity.go @@ -0,0 +1,142 @@ +package models + +import ( + "database/sql" + "strings" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type Identity struct { + // returned as identity_id in JSON for backward compatibility with the interface exposed by the client library + // see https://github.com/supabase/gotrue-js/blob/c9296bbc27a2f036af55c1f33fca5930704bd021/src/lib/types.ts#L230-L240 + ID uuid.UUID `json:"identity_id" db:"id"` + // returned as id in JSON for backward compatibility with the interface exposed by the client library + // see https://github.com/supabase/gotrue-js/blob/c9296bbc27a2f036af55c1f33fca5930704bd021/src/lib/types.ts#L230-L240 + ProviderID string `json:"id" db:"provider_id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + IdentityData JSONMap `json:"identity_data,omitempty" db:"identity_data"` + Provider string `json:"provider" db:"provider"` + LastSignInAt *time.Time `json:"last_sign_in_at,omitempty" db:"last_sign_in_at"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + Email storage.NullString `json:"email,omitempty" db:"email" rw:"r"` +} + +func (Identity) TableName() string { + tableName := "identities" + return tableName +} + +// GetEmail returns the user's email as a string +func (i *Identity) GetEmail() string { + return string(i.Email) +} + +// NewIdentity returns an identity associated to the user's id. +func NewIdentity(user *User, provider string, identityData map[string]interface{}) (*Identity, error) { + providerId, ok := identityData["sub"] + if !ok { + return nil, errors.New("error missing provider id") + } + now := time.Now() + + identity := &Identity{ + ProviderID: providerId.(string), + UserID: user.ID, + IdentityData: identityData, + Provider: provider, + LastSignInAt: &now, + } + if email, ok := identityData["email"]; ok { + identity.Email = storage.NullString(email.(string)) + } + + return identity, nil +} + +func (i *Identity) BeforeCreate(tx *pop.Connection) error { + return i.BeforeUpdate(tx) +} + +func (i *Identity) BeforeUpdate(tx *pop.Connection) error { + if _, ok := i.IdentityData["email"]; ok { + i.IdentityData["email"] = strings.ToLower(i.IdentityData["email"].(string)) + } + return nil +} + +func (i *Identity) IsForSSOProvider() bool { + return strings.HasPrefix(i.Provider, "sso:") +} + +// FindIdentityById searches for an identity with the matching id and provider given. +func FindIdentityByIdAndProvider(tx *storage.Connection, providerId, provider string) (*Identity, error) { + identity := &Identity{} + if err := tx.Q().Where("provider_id = ? AND provider = ?", providerId, provider).First(identity); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, IdentityNotFoundError{} + } + return nil, errors.Wrap(err, "error finding identity") + } + return identity, nil +} + +// FindIdentitiesByUserID returns all identities associated to a user ID. +func FindIdentitiesByUserID(tx *storage.Connection, userID uuid.UUID) ([]*Identity, error) { + identities := []*Identity{} + if err := tx.Q().Where("user_id = ?", userID).All(&identities); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return identities, nil + } + return nil, errors.Wrap(err, "error finding identities") + } + return identities, nil +} + +// FindProvidersByUser returns all providers associated to a user +func FindProvidersByUser(tx *storage.Connection, user *User) ([]string, error) { + identities := []Identity{} + providerExists := map[string]bool{} + providers := make([]string, 0) + if err := tx.Q().Select("provider").Where("user_id = ?", user.ID).All(&identities); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return providers, nil + } + return nil, errors.Wrap(err, "error finding providers") + } + for _, identity := range identities { + if _, ok := providerExists[identity.Provider]; !ok { + providers = append(providers, identity.Provider) + providerExists[identity.Provider] = true + } + } + return providers, nil +} + +// UpdateIdentityData sets all identity_data from a map of updates, +// ensuring that it doesn't override attributes that are not +// in the provided map. +func (i *Identity) UpdateIdentityData(tx *storage.Connection, updates map[string]interface{}) error { + if i.IdentityData == nil { + i.IdentityData = updates + } else { + for key, value := range updates { + if value != nil { + i.IdentityData[key] = value + } else { + delete(i.IdentityData, key) + } + } + } + // pop doesn't support updates on tables with composite primary keys so we use a raw query here. + return tx.RawQuery( + "update "+(&pop.Model{Value: Identity{}}).TableName()+" set identity_data = ? where id = ?", + i.IdentityData, + i.ID, + ).Exec() +} diff --git a/models/identity_test.go b/internal/models/identity_test.go similarity index 68% rename from models/identity_test.go rename to internal/models/identity_test.go index c94b6a567..d27d17b67 100644 --- a/models/identity_test.go +++ b/internal/models/identity_test.go @@ -4,11 +4,11 @@ import ( "testing" "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/storage" - "github.com/netlify/gotrue/storage/test" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" ) type IdentityTestSuite struct { @@ -53,15 +53,38 @@ func (ts *IdentityTestSuite) TestNewIdentity() { func (ts *IdentityTestSuite) TestFindUserIdentities() { u := ts.createUserWithIdentity("test@supabase.io") - identities, err := FindIdentitiesByUser(ts.db, u) + identities, err := FindIdentitiesByUserID(ts.db, u.ID) require.NoError(ts.T(), err) require.Len(ts.T(), identities, 1) } +func (ts *IdentityTestSuite) TestUpdateIdentityData() { + u := ts.createUserWithIdentity("test@supabase.io") + + identities, err := FindIdentitiesByUserID(ts.db, u.ID) + require.NoError(ts.T(), err) + + updates := map[string]interface{}{ + "sub": nil, + "name": nil, + "email": nil, + } + for _, identity := range identities { + err := identity.UpdateIdentityData(ts.db, updates) + require.NoError(ts.T(), err) + } + + updatedIdentities, err := FindIdentitiesByUserID(ts.db, u.ID) + require.NoError(ts.T(), err) + for _, identity := range updatedIdentities { + require.Empty(ts.T(), identity.IdentityData) + } +} + func (ts *IdentityTestSuite) createUserWithEmail(email string) *User { - user, err := NewUser(uuid.Nil, email, "secret", "test", nil) + user, err := NewUser("", email, "secret", "test", nil) require.NoError(ts.T(), err) err = ts.db.Create(user) @@ -71,7 +94,7 @@ func (ts *IdentityTestSuite) createUserWithEmail(email string) *User { } func (ts *IdentityTestSuite) createUserWithIdentity(email string) *User { - user, err := NewUser(uuid.Nil, email, "secret", "test", nil) + user, err := NewUser("", email, "secret", "test", nil) require.NoError(ts.T(), err) err = ts.db.Create(user) diff --git a/models/json_map.go b/internal/models/json_map.go similarity index 86% rename from models/json_map.go rename to internal/models/json_map.go index 6db3c998b..77cee649b 100644 --- a/models/json_map.go +++ b/internal/models/json_map.go @@ -23,8 +23,10 @@ func (j JSONMap) Scan(src interface{}) error { source = []byte(v) case []byte: source = v + case nil: + source = []byte("") default: - return errors.New("Invalid data type for JSONMap") + return errors.New("invalid data type for JSONMap") } if len(source) == 0 { diff --git a/internal/models/linking.go b/internal/models/linking.go new file mode 100644 index 000000000..ca794bc5d --- /dev/null +++ b/internal/models/linking.go @@ -0,0 +1,203 @@ +package models + +import ( + "strings" + + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" +) + +// GetAccountLinkingDomain returns a string that describes the account linking +// domain. An account linking domain describes a set of Identity entities that +// _should_ generally fall under the same User entity. It's just a runtime +// string, and is not typically persisted in the database. This value can vary +// across time. +func GetAccountLinkingDomain(provider string) string { + if strings.HasPrefix(provider, "sso:") { + // when the provider ID is a SSO provider, then the linking + // domain is the provider itself i.e. there can only be one + // user + identity per identity provider + return provider + } + + // otherwise, the linking domain is the default linking domain that + // links all accounts + return "default" +} + +type AccountLinkingDecision = int + +const ( + AccountExists AccountLinkingDecision = iota + CreateAccount + LinkAccount + MultipleAccounts +) + +type AccountLinkingResult struct { + Decision AccountLinkingDecision + User *User + Identities []*Identity + LinkingDomain string + CandidateEmail provider.Email +} + +// DetermineAccountLinking uses the provided data and database state to compute a decision on whether: +// - A new User should be created (CreateAccount) +// - A new Identity should be created (LinkAccount) with a UserID pointing to an existing user account +// - Nothing should be done (AccountExists) +// - It's not possible to decide due to data inconsistency (MultipleAccounts) and the caller should decide +// +// Errors signal failure in processing only, like database access errors. +func DetermineAccountLinking(tx *storage.Connection, config *conf.GlobalConfiguration, emails []provider.Email, aud, providerName, sub string) (AccountLinkingResult, error) { + var verifiedEmails []string + var candidateEmail provider.Email + for _, email := range emails { + if email.Verified || config.Mailer.Autoconfirm { + verifiedEmails = append(verifiedEmails, strings.ToLower(email.Email)) + } + if email.Primary { + candidateEmail = email + candidateEmail.Email = strings.ToLower(email.Email) + } + } + + if identity, terr := FindIdentityByIdAndProvider(tx, sub, providerName); terr == nil { + // account exists + + var user *User + if user, terr = FindUserByID(tx, identity.UserID); terr != nil { + return AccountLinkingResult{}, terr + } + + // we overwrite the email with the existing user's email since the user + // could have an empty email + candidateEmail.Email = user.GetEmail() + return AccountLinkingResult{ + Decision: AccountExists, + User: user, + Identities: []*Identity{identity}, + LinkingDomain: GetAccountLinkingDomain(providerName), + CandidateEmail: candidateEmail, + }, nil + } else if !IsNotFoundError(terr) { + return AccountLinkingResult{}, terr + } + + // the identity does not exist, so we need to check if we should create a new account + // or link to an existing one + + // this is the linking domain for the new identity + candidateLinkingDomain := GetAccountLinkingDomain(providerName) + if len(verifiedEmails) == 0 { + // if there are no verified emails, we always decide to create a new account + user, terr := IsDuplicatedEmail(tx, candidateEmail.Email, aud, nil) + if terr != nil { + return AccountLinkingResult{}, terr + } + if user != nil { + candidateEmail.Email = "" + } + return AccountLinkingResult{ + Decision: CreateAccount, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } + + var similarIdentities []*Identity + var similarUsers []*User + // look for similar identities and users based on email + if terr := tx.Q().Eager().Where("email = any (?)", verifiedEmails).All(&similarIdentities); terr != nil { + return AccountLinkingResult{}, terr + } + + if !strings.HasPrefix(providerName, "sso:") { + // there can be multiple user accounts with the same email when is_sso_user is true + // so we just do not consider those similar user accounts + if terr := tx.Q().Eager().Where("email = any (?) and is_sso_user = false", verifiedEmails).All(&similarUsers); terr != nil { + return AccountLinkingResult{}, terr + } + } + + // Need to check if the new identity should be assigned to an + // existing user or to create a new user, according to the automatic + // linking rules + var linkingIdentities []*Identity + + // now let's see if there are any existing and similar identities in + // the same linking domain + for _, identity := range similarIdentities { + if GetAccountLinkingDomain(identity.Provider) == candidateLinkingDomain { + linkingIdentities = append(linkingIdentities, identity) + } + } + + if len(linkingIdentities) == 0 { + if len(similarUsers) == 1 { + // no similarIdentities but a user with the same email exists + // so we link this new identity to the user + // TODO: Backfill the missing identity for the user + return AccountLinkingResult{ + Decision: LinkAccount, + User: similarUsers[0], + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } else if len(similarUsers) > 1 { + // this shouldn't happen since there is a partial unique index on (email and is_sso_user = false) + return AccountLinkingResult{ + Decision: MultipleAccounts, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } else { + // there are no identities in the linking domain, we have to + // create a new identity and new user + return AccountLinkingResult{ + Decision: CreateAccount, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } + } + + // there is at least one identity in the linking domain let's do a + // sanity check to see if all of the identities in the domain share the + // same user ID + linkingUserId := linkingIdentities[0].UserID + for _, identity := range linkingIdentities { + if identity.UserID != linkingUserId { + // ok this linking domain has more than one user account + // caller should decide what to do + + return AccountLinkingResult{ + Decision: MultipleAccounts, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } + } + + // there's only one user ID in this linking domain, we can go on and + // create a new identity and link it to the existing account + + var user *User + var terr error + + if user, terr = FindUserByID(tx, linkingUserId); terr != nil { + return AccountLinkingResult{}, terr + } + + return AccountLinkingResult{ + Decision: LinkAccount, + User: user, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil +} diff --git a/internal/models/linking_test.go b/internal/models/linking_test.go new file mode 100644 index 000000000..05d4a8c32 --- /dev/null +++ b/internal/models/linking_test.go @@ -0,0 +1,314 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type AccountLinkingTestSuite struct { + suite.Suite + + config *conf.GlobalConfiguration + db *storage.Connection +} + +func (ts *AccountLinkingTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestAccountLinking(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &AccountLinkingTestSuite{ + config: globalConfig, + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionNoAccounts() { + // when there are no accounts in the system -- conventional provider + testEmail := provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + } + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{testEmail}, ts.config.JWT.Aud, "provider", "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, CreateAccount) + + // when there are no accounts in the system -- SSO provider + decision, err = DetermineAccountLinking(ts.db, ts.config, []provider.Email{testEmail}, ts.config.JWT.Aud, "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, CreateAccount) +} + +func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { + userA, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + identityA, err := NewIdentity(userA, "provider", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityA)) + + userB, err := NewUser("", "test@samltest.id", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userB)) + + ssoProvider := "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387" + identityB, err := NewIdentity(userB, ssoProvider, map[string]interface{}{ + "sub": userB.ID.String(), + "email": "test@samltest.id", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityB)) + + // when the email doesn't exist in the system -- conventional provider + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ + { + Email: "other@example.com", + Verified: true, + Primary: true, + }, + }, ts.config.JWT.Aud, "provider", "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, CreateAccount) + require.Equal(ts.T(), decision.LinkingDomain, "default") + + // when looking for an email that doesn't exist in the SSO linking domain + decision, err = DetermineAccountLinking(ts.db, ts.config, []provider.Email{ + { + Email: "other@samltest.id", + Verified: true, + Primary: true, + }, + }, ts.config.JWT.Aud, ssoProvider, "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, CreateAccount) + require.Equal(ts.T(), decision.LinkingDomain, ssoProvider) +} + +func (ts *AccountLinkingTestSuite) TestAccountExists() { + userA, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + identityA, err := NewIdentity(userA, "provider", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityA)) + + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ + { + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, ts.config.JWT.Aud, "provider", userA.ID.String()) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, AccountExists) + require.Equal(ts.T(), decision.User.ID, userA.ID) +} + +func (ts *AccountLinkingTestSuite) TestLinkingScenarios() { + userA, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + identityA, err := NewIdentity(userA, "provider", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityA)) + + userB, err := NewUser("", "test@samltest.id", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userB)) + + identityB, err := NewIdentity(userB, "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", map[string]interface{}{ + "sub": userB.ID.String(), + "email": "test@samltest.id", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityB)) + + cases := []struct { + desc string + email provider.Email + sub string + provider string + decision AccountLinkingResult + }{ + { + // link decision because the below described identity is in the default linking domain but uses "other-provider" instead of "provder" + desc: "same email address", + email: provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + }, + sub: userA.ID.String(), + provider: "other-provider", + decision: AccountLinkingResult{ + Decision: LinkAccount, + User: userA, + LinkingDomain: "default", + CandidateEmail: provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, + }, + { + desc: "same email address in uppercase", + email: provider.Email{ + Email: "TEST@example.com", + Verified: true, + Primary: true, + }, + sub: userA.ID.String(), + provider: "other-provider", + decision: AccountLinkingResult{ + Decision: LinkAccount, + User: userA, + LinkingDomain: "default", + CandidateEmail: provider.Email{ + // expected email should be case insensitive + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, + }, + { + desc: "no link decision because the SSO linking domain is scoped to the provider unique ID", + email: provider.Email{ + Email: "test@samltest.id", + Verified: true, + Primary: true, + }, + sub: userB.ID.String(), + provider: "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", + // decision: AccountExists, + decision: AccountLinkingResult{ + Decision: AccountExists, + User: userB, + LinkingDomain: "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", + CandidateEmail: provider.Email{ + Email: "test@samltest.id", + Verified: true, + Primary: true, + }, + }, + }, + { + desc: "create account with empty email because email is unverified and user exists", + email: provider.Email{ + Email: "test@example.com", + Verified: false, + Primary: true, + }, + sub: userA.ID.String(), + provider: "other-provider", + decision: AccountLinkingResult{ + Decision: CreateAccount, + LinkingDomain: "default", + CandidateEmail: provider.Email{ + Email: "", + Verified: false, + Primary: true, + }, + }, + }, + { + desc: "create account because email is unverified and user doesn't exist", + email: provider.Email{ + Email: "other@example.com", + Verified: false, + Primary: true, + }, + sub: "000000000", + provider: "other-provider", + decision: AccountLinkingResult{ + Decision: CreateAccount, + LinkingDomain: "default", + CandidateEmail: provider.Email{ + Email: "other@example.com", + Verified: false, + Primary: true, + }, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{c.email}, ts.config.JWT.Aud, c.provider, c.sub) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.decision.Decision, decision.Decision) + require.Equal(ts.T(), c.decision.LinkingDomain, decision.LinkingDomain) + require.Equal(ts.T(), c.decision.CandidateEmail.Email, decision.CandidateEmail.Email) + require.Equal(ts.T(), c.decision.CandidateEmail.Verified, decision.CandidateEmail.Verified) + require.Equal(ts.T(), c.decision.CandidateEmail.Primary, decision.CandidateEmail.Primary) + }) + } + +} + +func (ts *AccountLinkingTestSuite) TestMultipleAccounts() { + userA, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + identityA, err := NewIdentity(userA, "provider", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityA)) + + userB, err := NewUser("", "test-b@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userB)) + identityB, err := NewIdentity(userB, "provider", map[string]interface{}{ + "sub": userB.ID.String(), + "email": "test@example.com", // intentionally same as userA + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityB)) + + // decision is multiple accounts because there are two distinct + // identities in the same "default" linking domain with the same email + // address pointing to two different user accounts + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ + { + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, ts.config.JWT.Aud, "provider", "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, MultipleAccounts) +} diff --git a/internal/models/one_time_token.go b/internal/models/one_time_token.go new file mode 100644 index 000000000..53805a423 --- /dev/null +++ b/internal/models/one_time_token.go @@ -0,0 +1,286 @@ +package models + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type OneTimeTokenType int + +const ( + ConfirmationToken OneTimeTokenType = iota + ReauthenticationToken + RecoveryToken + EmailChangeTokenNew + EmailChangeTokenCurrent + PhoneChangeToken +) + +func (t OneTimeTokenType) String() string { + switch t { + case ConfirmationToken: + return "confirmation_token" + + case ReauthenticationToken: + return "reauthentication_token" + + case RecoveryToken: + return "recovery_token" + + case EmailChangeTokenNew: + return "email_change_token_new" + + case EmailChangeTokenCurrent: + return "email_change_token_current" + + case PhoneChangeToken: + return "phone_change_token" + + default: + panic("OneTimeToken: unreachable case") + } +} + +func ParseOneTimeTokenType(s string) (OneTimeTokenType, error) { + switch s { + case "confirmation_token": + return ConfirmationToken, nil + + case "reauthentication_token": + return ReauthenticationToken, nil + + case "recovery_token": + return RecoveryToken, nil + + case "email_change_token_new": + return EmailChangeTokenNew, nil + + case "email_change_token_current": + return EmailChangeTokenCurrent, nil + + case "phone_change_token": + return PhoneChangeToken, nil + + default: + return 0, fmt.Errorf("OneTimeTokenType: unrecognized string %q", s) + } +} + +func (t OneTimeTokenType) Value() (driver.Value, error) { + return t.String(), nil +} + +func (t *OneTimeTokenType) Scan(src interface{}) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("OneTimeTokenType: scan type is not string but is %T", src) + } + + parsed, err := ParseOneTimeTokenType(s) + if err != nil { + return err + } + + *t = parsed + return nil +} + +type OneTimeTokenNotFoundError struct { +} + +func (e OneTimeTokenNotFoundError) Error() string { + return "One-time token not found" +} + +type OneTimeToken struct { + ID uuid.UUID `json:"id" db:"id"` + + UserID uuid.UUID `json:"user_id" db:"user_id"` + TokenType OneTimeTokenType `json:"token_type" db:"token_type"` + + TokenHash string `json:"token_hash" db:"token_hash"` + RelatesTo string `json:"relates_to" db:"relates_to"` + + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +func (OneTimeToken) TableName() string { + return "one_time_tokens" +} + +func ClearAllOneTimeTokensForUser(tx *storage.Connection, userID uuid.UUID) error { + return tx.Q().Where("user_id = ?", userID).Delete(OneTimeToken{}) +} + +func ClearOneTimeTokenForUser(tx *storage.Connection, userID uuid.UUID, tokenType OneTimeTokenType) error { + if err := tx.Q().Where("token_type = ? and user_id = ?", tokenType, userID).Delete(OneTimeToken{}); err != nil { + return err + } + + return nil +} + +func CreateOneTimeToken(tx *storage.Connection, userID uuid.UUID, relatesTo, tokenHash string, tokenType OneTimeTokenType) error { + if err := ClearOneTimeTokenForUser(tx, userID, tokenType); err != nil { + return err + } + + oneTimeToken := &OneTimeToken{ + ID: uuid.Must(uuid.NewV4()), + UserID: userID, + TokenType: tokenType, + TokenHash: tokenHash, + RelatesTo: strings.ToLower(relatesTo), + } + + if err := tx.Eager().Create(oneTimeToken); err != nil { + return err + } + + return nil +} + +func FindOneTimeToken(tx *storage.Connection, tokenHash string, tokenTypes ...OneTimeTokenType) (*OneTimeToken, error) { + oneTimeToken := &OneTimeToken{} + + query := tx.Eager().Q() + + switch len(tokenTypes) { + case 2: + query = query.Where("(token_type = ? or token_type = ?) and token_hash = ?", tokenTypes[0], tokenTypes[1], tokenHash) + + case 1: + query = query.Where("token_type = ? and token_hash = ?", tokenTypes[0], tokenHash) + + default: + panic("at most 2 token types are accepted") + } + + if err := query.First(oneTimeToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OneTimeTokenNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding one time token") + } + + return oneTimeToken, nil +} + +// FindUserByConfirmationToken finds users with the matching confirmation token. +func FindUserByConfirmationOrRecoveryToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, ConfirmationToken, RecoveryToken) + if err != nil { + return nil, err + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByConfirmationToken finds users with the matching confirmation token. +func FindUserByConfirmationToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, ConfirmationToken) + if err != nil { + return nil, err + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByRecoveryToken finds a user with the matching recovery token. +func FindUserByRecoveryToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, RecoveryToken) + if err != nil { + return nil, err + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByEmailChangeToken finds a user with the matching email change token. +func FindUserByEmailChangeToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent, EmailChangeTokenNew) + if err != nil { + return nil, err + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByEmailChangeCurrentAndAudience finds a user with the matching email change and audience. +func FindUserByEmailChangeCurrentAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenCurrent) + if err != nil { + return nil, err + } + } + if ott == nil { + return nil, err + } + + user, err := FindUserByID(tx, ott.UserID) + if err != nil { + return nil, err + } + + if user.Aud != aud && strings.EqualFold(user.GetEmail(), email) { + return nil, UserNotFoundError{} + } + + return user, nil +} + +// FindUserByEmailChangeNewAndAudience finds a user with the matching email change and audience. +func FindUserByEmailChangeNewAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenNew) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenNew) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + } + if ott == nil { + return nil, err + } + + user, err := FindUserByID(tx, ott.UserID) + if err != nil { + return nil, err + } + + if user.Aud != aud && strings.EqualFold(user.EmailChange, email) { + return nil, UserNotFoundError{} + } + + return user, nil +} + +// FindUserForEmailChange finds a user requesting for an email change +func FindUserForEmailChange(tx *storage.Connection, email, token, aud string, secureEmailChangeEnabled bool) (*User, error) { + if secureEmailChangeEnabled { + if user, err := FindUserByEmailChangeCurrentAndAudience(tx, email, token, aud); err == nil { + return user, err + } else if !IsNotFoundError(err) { + return nil, err + } + } + return FindUserByEmailChangeNewAndAudience(tx, email, token, aud) +} diff --git a/internal/models/refresh_token.go b/internal/models/refresh_token.go new file mode 100644 index 000000000..b683a8c77 --- /dev/null +++ b/internal/models/refresh_token.go @@ -0,0 +1,166 @@ +package models + +import ( + "database/sql" + "net/http" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +// RefreshToken is the database model for refresh tokens. +type RefreshToken struct { + ID int64 `db:"id"` + + Token string `db:"token"` + + UserID uuid.UUID `db:"user_id"` + + Parent storage.NullString `db:"parent"` + SessionId *uuid.UUID `db:"session_id"` + + Revoked bool `db:"revoked"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + + DONTUSEINSTANCEID uuid.UUID `json:"-" db:"instance_id"` +} + +func (RefreshToken) TableName() string { + tableName := "refresh_tokens" + return tableName +} + +// GrantParams is used to pass session-specific parameters when issuing a new +// refresh token to authenticated users. +type GrantParams struct { + FactorID *uuid.UUID + + SessionNotAfter *time.Time + SessionTag *string + + UserAgent string + IP string +} + +func (g *GrantParams) FillGrantParams(r *http.Request) { + g.UserAgent = r.Header.Get("User-Agent") + g.IP = utilities.GetIPAddress(r) +} + +// GrantAuthenticatedUser creates a refresh token for the provided user. +func GrantAuthenticatedUser(tx *storage.Connection, user *User, params GrantParams) (*RefreshToken, error) { + return createRefreshToken(tx, user, nil, ¶ms) +} + +// GrantRefreshTokenSwap swaps a refresh token for a new one, revoking the provided token. +func GrantRefreshTokenSwap(r *http.Request, tx *storage.Connection, user *User, token *RefreshToken) (*RefreshToken, error) { + var newToken *RefreshToken + err := tx.Transaction(func(rtx *storage.Connection) error { + var terr error + if terr = NewAuditLogEntry(r, tx, user, TokenRevokedAction, "", nil); terr != nil { + return errors.Wrap(terr, "error creating audit log entry") + } + + token.Revoked = true + if terr = tx.UpdateOnly(token, "revoked"); terr != nil { + return terr + } + + newToken, terr = createRefreshToken(rtx, user, token, &GrantParams{}) + return terr + }) + return newToken, err +} + +// RevokeTokenFamily revokes all refresh tokens that descended from the provided token. +func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error { + var err error + tablename := (&pop.Model{Value: RefreshToken{}}).TableName() + if token.SessionId != nil { + err = tx.RawQuery(`update `+tablename+` set revoked = true, updated_at = now() where session_id = ? and revoked = false;`, token.SessionId).Exec() + } else { + err = tx.RawQuery(` + with recursive token_family as ( + select id, user_id, token, revoked, parent from `+tablename+` where parent = ? + union + select r.id, r.user_id, r.token, r.revoked, r.parent from `+tablename+` r inner join token_family t on t.token = r.parent + ) + update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec() + } + if err != nil { + if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) { + return nil + } + + return err + } + return nil +} + +func FindTokenBySessionID(tx *storage.Connection, sessionId *uuid.UUID) (*RefreshToken, error) { + refreshToken := &RefreshToken{} + err := tx.Q().Where("instance_id = ? and session_id = ?", uuid.Nil, sessionId).Order("created_at asc").First(refreshToken) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, RefreshTokenNotFoundError{} + } + return nil, err + } + return refreshToken, nil +} + +func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken, params *GrantParams) (*RefreshToken, error) { + token := &RefreshToken{ + UserID: user.ID, + Token: crypto.SecureAlphanumeric(12), + Parent: "", + } + if oldToken != nil { + token.Parent = storage.NullString(oldToken.Token) + token.SessionId = oldToken.SessionId + } + + if token.SessionId == nil { + session, err := NewSession(user.ID, params.FactorID) + if err != nil { + return nil, errors.Wrap(err, "error instantiating new session object") + } + + if params.SessionNotAfter != nil { + session.NotAfter = params.SessionNotAfter + } + + if params.UserAgent != "" { + session.UserAgent = ¶ms.UserAgent + } + + if params.IP != "" { + session.IP = ¶ms.IP + } + + if params.SessionTag != nil && *params.SessionTag != "" { + session.Tag = params.SessionTag + } + + if err := tx.Create(session); err != nil { + return nil, errors.Wrap(err, "error creating new session") + } + + token.SessionId = &session.ID + } + + if err := tx.Create(token); err != nil { + return nil, errors.Wrap(err, "error creating refresh token") + } + + if err := user.UpdateLastSignInAt(tx); err != nil { + return nil, errors.Wrap(err, "error update user`s last_sign_in field") + } + return token, nil +} diff --git a/models/refresh_token_test.go b/internal/models/refresh_token_test.go similarity index 73% rename from models/refresh_token_test.go rename to internal/models/refresh_token_test.go index 96df7262f..675826d8b 100644 --- a/models/refresh_token_test.go +++ b/internal/models/refresh_token_test.go @@ -1,14 +1,14 @@ package models import ( + "net/http" "testing" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/storage" - "github.com/netlify/gotrue/storage/test" - "github.com/gofrs/uuid" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" ) type RefreshTokenTestSuite struct { @@ -37,7 +37,7 @@ func TestRefreshToken(t *testing.T) { func (ts *RefreshTokenTestSuite) TestGrantAuthenticatedUser() { u := ts.createUser() - r, err := GrantAuthenticatedUser(ts.db, u) + r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) require.NoError(ts.T(), err) require.NotEmpty(ts.T(), r.Token) @@ -46,13 +46,13 @@ func (ts *RefreshTokenTestSuite) TestGrantAuthenticatedUser() { func (ts *RefreshTokenTestSuite) TestGrantRefreshTokenSwap() { u := ts.createUser() - r, err := GrantAuthenticatedUser(ts.db, u) + r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) require.NoError(ts.T(), err) - s, err := GrantRefreshTokenSwap(ts.db, u, r) + s, err := GrantRefreshTokenSwap(&http.Request{}, ts.db, u, r) require.NoError(ts.T(), err) - _, nr, err := FindUserWithRefreshToken(ts.db, r.Token) + _, nr, _, err := FindUserWithRefreshToken(ts.db, r.Token, false) require.NoError(ts.T(), err) require.Equal(ts.T(), r.ID, nr.ID) @@ -64,11 +64,11 @@ func (ts *RefreshTokenTestSuite) TestGrantRefreshTokenSwap() { func (ts *RefreshTokenTestSuite) TestLogout() { u := ts.createUser() - r, err := GrantAuthenticatedUser(ts.db, u) + r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) require.NoError(ts.T(), err) - require.NoError(ts.T(), Logout(ts.db, uuid.Nil, u.ID)) - u, r, err = FindUserWithRefreshToken(ts.db, r.Token) + require.NoError(ts.T(), Logout(ts.db, u.ID)) + u, r, _, err = FindUserWithRefreshToken(ts.db, r.Token, false) require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, r) require.True(ts.T(), IsNotFoundError(err), "expected NotFoundError") @@ -79,7 +79,7 @@ func (ts *RefreshTokenTestSuite) createUser() *User { } func (ts *RefreshTokenTestSuite) createUserWithEmail(email string) *User { - user, err := NewUser(uuid.Nil, email, "secret", "test", nil) + user, err := NewUser("", email, "secret", "test", nil) require.NoError(ts.T(), err) err = ts.db.Create(user) diff --git a/internal/models/sessions.go b/internal/models/sessions.go new file mode 100644 index 000000000..a93be44ac --- /dev/null +++ b/internal/models/sessions.go @@ -0,0 +1,356 @@ +package models + +import ( + "database/sql" + "fmt" + "sort" + "strings" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type AuthenticatorAssuranceLevel int + +const ( + AAL1 AuthenticatorAssuranceLevel = iota + AAL2 + AAL3 +) + +func (aal AuthenticatorAssuranceLevel) String() string { + switch aal { + case AAL1: + return "aal1" + case AAL2: + return "aal2" + case AAL3: + return "aal3" + default: + return "" + } +} + +// AMREntry represents a method that a user has logged in together with the corresponding time +type AMREntry struct { + Method string `json:"method"` + Timestamp int64 `json:"timestamp"` + Provider string `json:"provider,omitempty"` +} + +type sortAMREntries struct { + Array []AMREntry +} + +func (s sortAMREntries) Len() int { + return len(s.Array) +} + +func (s sortAMREntries) Less(i, j int) bool { + return s.Array[i].Timestamp < s.Array[j].Timestamp +} + +func (s sortAMREntries) Swap(i, j int) { + s.Array[j], s.Array[i] = s.Array[i], s.Array[j] +} + +type Session struct { + ID uuid.UUID `json:"-" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + + // NotAfter is overriden by timeboxed sessions. + NotAfter *time.Time `json:"not_after,omitempty" db:"not_after"` + + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + FactorID *uuid.UUID `json:"factor_id" db:"factor_id"` + AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"` + AAL *string `json:"aal" db:"aal"` + + RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"` + UserAgent *string `json:"user_agent,omitempty" db:"user_agent"` + IP *string `json:"ip,omitempty" db:"ip"` + + Tag *string `json:"tag" db:"tag"` +} + +func (Session) TableName() string { + tableName := "sessions" + return tableName +} + +func (s *Session) LastRefreshedAt(refreshTokenTime *time.Time) time.Time { + refreshedAt := s.RefreshedAt + + if refreshedAt == nil || refreshedAt.IsZero() { + if refreshTokenTime != nil { + rtt := *refreshTokenTime + + if rtt.IsZero() { + return s.CreatedAt + } else if rtt.After(s.CreatedAt) { + return rtt + } + } + + return s.CreatedAt + } + + return *refreshedAt +} + +func (s *Session) UpdateOnlyRefreshInfo(tx *storage.Connection) error { + // TODO(kangmingtay): The underlying database type uses timestamp without timezone, + // so we need to convert the value to UTC before updating it. + // In the future, we should add a migration to update the type to contain the timezone. + *s.RefreshedAt = s.RefreshedAt.UTC() + return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip") +} + +type SessionValidityReason = int + +const ( + SessionValid SessionValidityReason = iota + SessionPastNotAfter = iota + SessionPastTimebox = iota + SessionTimedOut = iota +) + +func (s *Session) CheckValidity(now time.Time, refreshTokenTime *time.Time, timebox, inactivityTimeout *time.Duration) SessionValidityReason { + if s.NotAfter != nil && now.After(*s.NotAfter) { + return SessionPastNotAfter + } + + if timebox != nil && *timebox != 0 && now.After(s.CreatedAt.Add(*timebox)) { + return SessionPastTimebox + } + + if inactivityTimeout != nil && *inactivityTimeout != 0 && now.After(s.LastRefreshedAt(refreshTokenTime).Add(*inactivityTimeout)) { + return SessionTimedOut + } + + return SessionValid +} + +func (s *Session) DetermineTag(tags []string) string { + if len(tags) == 0 { + return "" + } + + if s.Tag == nil { + return tags[0] + } + + tag := *s.Tag + if tag == "" { + return tags[0] + } + + for _, t := range tags { + if t == tag { + return tag + } + } + + return tags[0] +} + +func NewSession(userID uuid.UUID, factorID *uuid.UUID) (*Session, error) { + id := uuid.Must(uuid.NewV4()) + + defaultAAL := AAL1.String() + + session := &Session{ + ID: id, + AAL: &defaultAAL, + UserID: userID, + FactorID: factorID, + } + + return session, nil +} + +// FindSessionByID looks up a Session by the provided id. If forUpdate is set +// to true, then the SELECT statement used by the query has the form SELECT ... +// FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE lock will only be +// acquired if there's no other lock. In case there is a lock, a +// IsNotFound(err) error will be retured. +func FindSessionByID(tx *storage.Connection, id uuid.UUID, forUpdate bool) (*Session, error) { + session := &Session{} + + if forUpdate { + // pop does not provide us with a way to execute FOR UPDATE + // queries which lock the rows affected by the query from + // being accessed by any other transaction that also uses FOR + // UPDATE + if err := tx.RawQuery(fmt.Sprintf("SELECT * FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", session.TableName()), id).First(session); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SessionNotFoundError{} + } + + return nil, err + } + } + + // once the rows are locked (if forUpdate was true), we can query again using pop + if err := tx.Eager().Q().Where("id = ?", id).First(session); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SessionNotFoundError{} + } + return nil, errors.Wrap(err, "error finding session") + } + return session, nil +} + +func FindSessionByUserID(tx *storage.Connection, userId uuid.UUID) (*Session, error) { + session := &Session{} + if err := tx.Eager().Q().Where("user_id = ?", userId).Order("created_at asc").First(session); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SessionNotFoundError{} + } + return nil, errors.Wrap(err, "error finding session") + } + return session, nil +} + +func FindSessionsByFactorID(tx *storage.Connection, factorID uuid.UUID) ([]*Session, error) { + sessions := []*Session{} + if err := tx.Q().Where("factor_id = ?", factorID).All(&sessions); err != nil { + return nil, err + } + return sessions, nil +} + +// FindAllSessionsForUser finds all of the sessions for a user. If forUpdate is +// set, it will first lock on the user row which can be used to prevent issues +// with concurrency. If the lock is acquired, it will return a +// UserNotFoundError and the operation should be retried. If there are no +// sessions for the user, a nil result is returned without an error. +func FindAllSessionsForUser(tx *storage.Connection, userId uuid.UUID, forUpdate bool) ([]*Session, error) { + if forUpdate { + user := &User{} + if err := tx.RawQuery(fmt.Sprintf("SELECT id FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", user.TableName()), userId).First(user); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, UserNotFoundError{} + } + + return nil, err + } + } + + var sessions []*Session + if err := tx.Where("user_id = ?", userId).All(&sessions); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil + } + + return nil, err + } + + return sessions, nil +} + +func updateFactorAssociatedSessions(tx *storage.Connection, userID, factorID uuid.UUID, aal string) error { + return tx.RawQuery("UPDATE "+(&pop.Model{Value: Session{}}).TableName()+" set aal = ?, factor_id = ? WHERE user_id = ? AND factor_id = ?", aal, nil, userID, factorID).Exec() +} + +func InvalidateSessionsWithAALLessThan(tx *storage.Connection, userID uuid.UUID, level string) error { + return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ? AND aal < ?", userID, level).Exec() +} + +// Logout deletes all sessions for a user. +func Logout(tx *storage.Connection, userId uuid.UUID) error { + return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ?", userId).Exec() +} + +// LogoutSession deletes the current session for a user +func LogoutSession(tx *storage.Connection, sessionId uuid.UUID) error { + return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id = ?", sessionId).Exec() +} + +// LogoutAllExceptMe deletes all sessions for a user except the current one +func LogoutAllExceptMe(tx *storage.Connection, sessionId uuid.UUID, userID uuid.UUID) error { + return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id != ? AND user_id = ?", sessionId, userID).Exec() +} + +func (s *Session) UpdateAALAndAssociatedFactor(tx *storage.Connection, aal AuthenticatorAssuranceLevel, factorID *uuid.UUID) error { + s.FactorID = factorID + aalAsString := aal.String() + s.AAL = &aalAsString + return tx.UpdateOnly(s, "aal", "factor_id") +} + +func (s *Session) CalculateAALAndAMR(user *User) (aal AuthenticatorAssuranceLevel, amr []AMREntry, err error) { + amr, aal = []AMREntry{}, AAL1 + for _, claim := range s.AMRClaims { + if claim.IsAAL2Claim() { + aal = AAL2 + } + amr = append(amr, AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()}) + } + + // makes sure that the AMR claims are always ordered most-recent first + + // sort in ascending order + sort.Sort(sortAMREntries{ + Array: amr, + }) + + // now reverse for descending order + _ = sort.Reverse(sortAMREntries{ + Array: amr, + }) + + lastIndex := len(amr) - 1 + + if lastIndex > -1 && amr[lastIndex].Method == SSOSAML.String() { + // initial AMR claim is from sso/saml, we need to add information + // about the provider that was used for the authentication + identities := user.Identities + + if len(identities) == 1 { + identity := identities[0] + + if identity.IsForSSOProvider() { + amr[lastIndex].Provider = strings.TrimPrefix(identity.Provider, "sso:") + } + } + + // otherwise we can't identify that this user account has only + // one SSO identity, so we are not encoding the provider at + // this time + } + + return aal, amr, nil +} + +func (s *Session) GetAAL() string { + if s.AAL == nil { + return "" + } + return *(s.AAL) +} + +func (s *Session) IsAAL2() bool { + return s.GetAAL() == AAL2.String() +} + +// FindCurrentlyActiveRefreshToken returns the currently active refresh +// token in the session. This is the last created (ordered by the serial +// primary key) non-revoked refresh token for the session. +func (s *Session) FindCurrentlyActiveRefreshToken(tx *storage.Connection) (*RefreshToken, error) { + var activeRefreshToken RefreshToken + + if err := tx.Q().Where("session_id = ? and revoked is false", s.ID).Order("id desc").First(&activeRefreshToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) { + return nil, RefreshTokenNotFoundError{} + } + + return nil, err + } + + return &activeRefreshToken, nil +} diff --git a/internal/models/sessions_test.go b/internal/models/sessions_test.go new file mode 100644 index 000000000..9dce78e95 --- /dev/null +++ b/internal/models/sessions_test.go @@ -0,0 +1,104 @@ +package models + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type SessionsTestSuite struct { + suite.Suite + db *storage.Connection + Config *conf.GlobalConfiguration +} + +func (ts *SessionsTestSuite) SetupTest() { + TruncateAll(ts.db) + email := "test@example.com" + user, err := NewUser("", email, "secret", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + + err = ts.db.Create(user) + require.NoError(ts.T(), err) +} + +func TestSession(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + ts := &SessionsTestSuite{ + db: conn, + Config: globalConfig, + } + defer ts.db.Close() + suite.Run(t, ts) +} + +func (ts *SessionsTestSuite) TestFindBySessionIDWithForUpdate() { + u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + session, err := NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(session)) + + found, err := FindSessionByID(ts.db, session.ID, true) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), session.ID, found.ID) +} + +func (ts *SessionsTestSuite) AddClaimAndReloadSession(session *Session, claim AuthenticationMethod) *Session { + err := AddClaimToSession(ts.db, session.ID, claim) + require.NoError(ts.T(), err) + session, err = FindSessionByID(ts.db, session.ID, false) + require.NoError(ts.T(), err) + return session +} + +func (ts *SessionsTestSuite) TestCalculateAALAndAMR() { + totalDistinctClaims := 3 + u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + session, err := NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(session)) + + session = ts.AddClaimAndReloadSession(session, PasswordGrant) + + firstClaimAddedTime := time.Now() + session = ts.AddClaimAndReloadSession(session, TOTPSignIn) + + _, _, err = session.CalculateAALAndAMR(u) + require.NoError(ts.T(), err) + + session = ts.AddClaimAndReloadSession(session, TOTPSignIn) + + session = ts.AddClaimAndReloadSession(session, SSOSAML) + + aal, amr, err := session.CalculateAALAndAMR(u) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), AAL2, aal) + require.Equal(ts.T(), totalDistinctClaims, len(amr)) + + found := false + for _, claim := range session.AMRClaims { + if claim.GetAuthenticationMethod() == TOTPSignIn.String() { + require.True(ts.T(), firstClaimAddedTime.Before(claim.UpdatedAt)) + found = true + } + } + + for _, claim := range amr { + if claim.Method == SSOSAML.String() { + require.NotNil(ts.T(), claim.Provider) + } + } + require.True(ts.T(), found) +} diff --git a/internal/models/sso.go b/internal/models/sso.go new file mode 100644 index 000000000..28c2429ac --- /dev/null +++ b/internal/models/sso.go @@ -0,0 +1,262 @@ +package models + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "reflect" + "strings" + "time" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type SSOProvider struct { + ID uuid.UUID `db:"id" json:"id"` + + SAMLProvider SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"` + SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"` + + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +func (p SSOProvider) TableName() string { + return "sso_providers" +} + +func (p SSOProvider) Type() string { + return "saml" +} + +type SAMLAttribute struct { + Name string `json:"name,omitempty"` + Names []string `json:"names,omitempty"` + Default interface{} `json:"default,omitempty"` + Array bool `json:"array,omitempty"` +} + +type SAMLAttributeMapping struct { + Keys map[string]SAMLAttribute `json:"keys,omitempty"` +} + +func (m *SAMLAttributeMapping) Equal(o *SAMLAttributeMapping) bool { + if m == o { + return true + } + + if m == nil || o == nil { + return false + } + + if m.Keys == nil && o.Keys == nil { + return true + } + + if len(m.Keys) != len(o.Keys) { + return false + } + + for mkey, mvalue := range m.Keys { + value, ok := o.Keys[mkey] + if !ok { + return false + } + + if mvalue.Name != value.Name || len(mvalue.Names) != len(value.Names) { + return false + } + + for i := 0; i < len(mvalue.Names); i += 1 { + if mvalue.Names[i] != value.Names[i] { + return false + } + } + + if !reflect.DeepEqual(mvalue.Default, value.Default) { + return false + } + + if mvalue.Array != value.Array { + return false + } + } + + return true +} + +func (m *SAMLAttributeMapping) Scan(src interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("scan source was not []byte") + } + err := json.Unmarshal(b, m) + if err != nil { + return err + } + return nil +} + +func (m SAMLAttributeMapping) Value() (driver.Value, error) { + b, err := json.Marshal(m) + if err != nil { + return nil, err + } + return string(b), nil +} + +type SAMLProvider struct { + ID uuid.UUID `db:"id" json:"-"` + + SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"` + SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"` + + EntityID string `db:"entity_id" json:"entity_id"` + MetadataXML string `db:"metadata_xml" json:"metadata_xml,omitempty"` + MetadataURL *string `db:"metadata_url" json:"metadata_url,omitempty"` + + AttributeMapping SAMLAttributeMapping `db:"attribute_mapping" json:"attribute_mapping,omitempty"` + + NameIDFormat *string `db:"name_id_format" json:"name_id_format,omitempty"` + + CreatedAt time.Time `db:"created_at" json:"-"` + UpdatedAt time.Time `db:"updated_at" json:"-"` +} + +func (p SAMLProvider) TableName() string { + return "saml_providers" +} + +func (p SAMLProvider) EntityDescriptor() (*saml.EntityDescriptor, error) { + return samlsp.ParseMetadata([]byte(p.MetadataXML)) +} + +type SSODomain struct { + ID uuid.UUID `db:"id" json:"-"` + + SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"` + SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"` + + Domain string `db:"domain" json:"domain"` + + CreatedAt time.Time `db:"created_at" json:"-"` + UpdatedAt time.Time `db:"updated_at" json:"-"` +} + +func (d SSODomain) TableName() string { + return "sso_domains" +} + +type SAMLRelayState struct { + ID uuid.UUID `db:"id"` + + SSOProviderID uuid.UUID `db:"sso_provider_id"` + + RequestID string `db:"request_id"` + ForEmail *string `db:"for_email"` + + RedirectTo string `db:"redirect_to"` + + CreatedAt time.Time `db:"created_at" json:"-"` + UpdatedAt time.Time `db:"updated_at" json:"-"` + FlowStateID *uuid.UUID `db:"flow_state_id" json:"flow_state_id,omitempty"` + FlowState *FlowState `db:"-" json:"flow_state,omitempty" belongs_to:"flow_state"` +} + +func (s SAMLRelayState) TableName() string { + return "saml_relay_states" +} + +func FindSAMLProviderByEntityID(tx *storage.Connection, entityId string) (*SSOProvider, error) { + var samlProvider SAMLProvider + if err := tx.Q().Where("entity_id = ?", entityId).First(&samlProvider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO provider by EntityID") + } + + var ssoProvider SSOProvider + if err := tx.Eager().Q().Where("id = ?", samlProvider.SSOProviderID).First(&ssoProvider); err != nil { + return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via EntityID)") + } + + return &ssoProvider, nil +} + +func FindSSOProviderByID(tx *storage.Connection, id uuid.UUID) (*SSOProvider, error) { + var ssoProvider SSOProvider + + if err := tx.Eager().Q().Where("id = ?", id).First(&ssoProvider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO provider by ID") + } + + return &ssoProvider, nil +} + +func FindSSOProviderForEmailAddress(tx *storage.Connection, emailAddress string) (*SSOProvider, error) { + parts := strings.Split(emailAddress, "@") + emailDomain := strings.ToLower(parts[1]) + + return FindSSOProviderByDomain(tx, emailDomain) +} + +func FindSSOProviderByDomain(tx *storage.Connection, domain string) (*SSOProvider, error) { + var ssoDomain SSODomain + + if err := tx.Q().Where("domain = ?", domain).First(&ssoDomain); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO domain") + } + + var ssoProvider SSOProvider + if err := tx.Eager().Q().Where("id = ?", ssoDomain.SSOProviderID).First(&ssoProvider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via domain)") + } + + return &ssoProvider, nil +} + +func FindAllSAMLProviders(tx *storage.Connection) ([]SSOProvider, error) { + var providers []SSOProvider + + if err := tx.Eager().All(&providers); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil + } + + return nil, errors.Wrap(err, "error loading all SAML SSO providers") + } + + return providers, nil +} + +func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelayState, error) { + var state SAMLRelayState + + if err := tx.Eager().Q().Where("id = ?", id).First(&state); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SAMLRelayStateNotFoundError{} + } + + return nil, errors.Wrap(err, "error loading SAML Relay State") + } + + return &state, nil +} diff --git a/internal/models/sso_test.go b/internal/models/sso_test.go new file mode 100644 index 000000000..b6c965630 --- /dev/null +++ b/internal/models/sso_test.go @@ -0,0 +1,232 @@ +package models + +import ( + tst "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type SSOTestSuite struct { + suite.Suite + + db *storage.Connection +} + +func (ts *SSOTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestSSO(t *tst.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &SSOTestSuite{ + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *SSOTestSuite) TestConstraints() { + type exampleSpec struct { + Provider *SSOProvider + } + + examples := []exampleSpec{ + { + Provider: &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "", + MetadataXML: "", + }, + }, + }, + { + Provider: &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + }, + }, + { + Provider: &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "", + }, + }, + }, + }, + } + + for i, example := range examples { + require.Error(ts.T(), ts.db.Eager().Create(example.Provider), "Example %d should have failed with error", i) + } +} + +func (ts *SSOTestSuite) TestDomainUniqueness() { + require.NoError(ts.T(), ts.db.Eager().Create(&SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata1", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + }, + })) + + require.Error(ts.T(), ts.db.Eager().Create(&SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata2", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + }, + })) +} + +func (ts *SSOTestSuite) TestEntityIDUniqueness() { + require.NoError(ts.T(), ts.db.Eager().Create(&SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + }, + })) + + require.Error(ts.T(), ts.db.Eager().Create(&SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.net", + }, + }, + })) +} + +func (ts *SSOTestSuite) TestFindSSOProviderForEmailAddress() { + provider := &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + { + Domain: "example.org", + }, + }, + } + + require.NoError(ts.T(), ts.db.Eager().Create(provider), "provider creation failed") + + type exampleSpec struct { + Address string + Provider *SSOProvider + } + + examples := []exampleSpec{ + { + Address: "someone@example.com", + Provider: provider, + }, + { + Address: "someone@example.org", + Provider: provider, + }, + { + Address: "someone@example.net", + Provider: nil, + }, + } + + for i, example := range examples { + rp, err := FindSSOProviderForEmailAddress(ts.db, example.Address) + + if nil == example.Provider { + require.Nil(ts.T(), rp) + require.True(ts.T(), IsNotFoundError(err), "Example %d failed with error %w", i, err) + } else { + require.Nil(ts.T(), err, "Example %d failed with error %w", i, err) + require.Equal(ts.T(), rp.ID, example.Provider.ID) + } + } +} + +func (ts *SSOTestSuite) TestFindSAMLProviderByEntityID() { + provider := &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + { + Domain: "example.org", + }, + }, + } + + require.NoError(ts.T(), ts.db.Eager().Create(provider)) + + type exampleSpec struct { + EntityID string + Provider *SSOProvider + } + + examples := []exampleSpec{ + { + EntityID: "https://example.com/saml/metadata", + Provider: provider, + }, + { + EntityID: "https://example.com/saml/metadata/", + Provider: nil, + }, + { + EntityID: "", + Provider: nil, + }, + } + + for i, example := range examples { + rp, err := FindSAMLProviderByEntityID(ts.db, example.EntityID) + + if nil == example.Provider { + require.True(ts.T(), IsNotFoundError(err), "Example %d failed with error", i) + require.Nil(ts.T(), rp) + } else { + require.Nil(ts.T(), err, "Example %d failed with error %w", i, err) + require.Equal(ts.T(), rp.ID, example.Provider.ID) + } + } +} diff --git a/internal/models/user.go b/internal/models/user.go new file mode 100644 index 000000000..01e728b95 --- /dev/null +++ b/internal/models/user.go @@ -0,0 +1,994 @@ +package models + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/base64" + "fmt" + "strings" + "time" + + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "golang.org/x/crypto/bcrypt" +) + +// User respresents a registered user with email/password authentication +type User struct { + ID uuid.UUID `json:"id" db:"id"` + + Aud string `json:"aud" db:"aud"` + Role string `json:"role" db:"role"` + Email storage.NullString `json:"email" db:"email"` + IsSSOUser bool `json:"-" db:"is_sso_user"` + + EncryptedPassword *string `json:"-" db:"encrypted_password"` + EmailConfirmedAt *time.Time `json:"email_confirmed_at,omitempty" db:"email_confirmed_at"` + InvitedAt *time.Time `json:"invited_at,omitempty" db:"invited_at"` + + Phone storage.NullString `json:"phone" db:"phone"` + PhoneConfirmedAt *time.Time `json:"phone_confirmed_at,omitempty" db:"phone_confirmed_at"` + + ConfirmationToken string `json:"-" db:"confirmation_token"` + ConfirmationSentAt *time.Time `json:"confirmation_sent_at,omitempty" db:"confirmation_sent_at"` + + PhoneConfirmationSentAt *time.Time `json:"phone_confirmation_sent_at,omitempty" db:"phone_confirmation_sent_at"` + + // For backward compatibility only. Use EmailConfirmedAt or PhoneConfirmedAt instead. + ConfirmedAt *time.Time `json:"confirmed_at,omitempty" db:"confirmed_at" rw:"r"` + + RecoveryToken string `json:"-" db:"recovery_token"` + RecoverySentAt *time.Time `json:"recovery_sent_at,omitempty" db:"recovery_sent_at"` + + EmailChangeTokenCurrent string `json:"-" db:"email_change_token_current"` + EmailChangeTokenNew string `json:"-" db:"email_change_token_new"` + EmailChange string `json:"new_email,omitempty" db:"email_change"` + EmailChangeSentAt *time.Time `json:"email_change_sent_at,omitempty" db:"email_change_sent_at"` + EmailChangeConfirmStatus int `json:"-" db:"email_change_confirm_status"` + + PhoneChangeToken string `json:"-" db:"phone_change_token"` + PhoneChange string `json:"new_phone,omitempty" db:"phone_change"` + PhoneChangeSentAt *time.Time `json:"phone_change_sent_at,omitempty" db:"phone_change_sent_at"` + + ReauthenticationToken string `json:"-" db:"reauthentication_token"` + ReauthenticationSentAt *time.Time `json:"reauthentication_sent_at,omitempty" db:"reauthentication_sent_at"` + + LastSignInAt *time.Time `json:"last_sign_in_at,omitempty" db:"last_sign_in_at"` + + AppMetaData JSONMap `json:"app_metadata" db:"raw_app_meta_data"` + UserMetaData JSONMap `json:"user_metadata" db:"raw_user_meta_data"` + + Factors []Factor `json:"factors,omitempty" has_many:"factors"` + Identities []Identity `json:"identities" has_many:"identities"` + + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + BannedUntil *time.Time `json:"banned_until,omitempty" db:"banned_until"` + DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` + IsAnonymous bool `json:"is_anonymous" db:"is_anonymous"` + + DONTUSEINSTANCEID uuid.UUID `json:"-" db:"instance_id"` +} + +func NewUserWithPasswordHash(phone, email, passwordHash, aud string, userData map[string]interface{}) (*User, error) { + if strings.HasPrefix(passwordHash, crypto.Argon2Prefix) { + _, err := crypto.ParseArgon2Hash(passwordHash) + if err != nil { + return nil, err + } + } else if strings.HasPrefix(passwordHash, crypto.FirebaseScryptPrefix) { + _, err := crypto.ParseFirebaseScryptHash(passwordHash) + if err != nil { + return nil, err + } + } else { + // verify that the hash is a bcrypt hash + _, err := bcrypt.Cost([]byte(passwordHash)) + if err != nil { + return nil, err + } + } + id := uuid.Must(uuid.NewV4()) + user := &User{ + ID: id, + Aud: aud, + Email: storage.NullString(strings.ToLower(email)), + Phone: storage.NullString(phone), + UserMetaData: userData, + EncryptedPassword: &passwordHash, + } + return user, nil +} + +// NewUser initializes a new user from an email, password and user data. +func NewUser(phone, email, password, aud string, userData map[string]interface{}) (*User, error) { + passwordHash := "" + + if password != "" { + pw, err := crypto.GenerateFromPassword(context.Background(), password) + if err != nil { + return nil, err + } + + passwordHash = pw + } + + if userData == nil { + userData = make(map[string]interface{}) + } + + id := uuid.Must(uuid.NewV4()) + user := &User{ + ID: id, + Aud: aud, + Email: storage.NullString(strings.ToLower(email)), + Phone: storage.NullString(phone), + UserMetaData: userData, + EncryptedPassword: &passwordHash, + } + return user, nil +} + +// TableName overrides the table name used by pop +func (User) TableName() string { + tableName := "users" + return tableName +} + +func (u *User) HasPassword() bool { + var pwd string + + if u.EncryptedPassword != nil { + pwd = *u.EncryptedPassword + } + + return pwd != "" +} + +// BeforeSave is invoked before the user is saved to the database +func (u *User) BeforeSave(tx *pop.Connection) error { + if u.EmailConfirmedAt != nil && u.EmailConfirmedAt.IsZero() { + u.EmailConfirmedAt = nil + } + if u.PhoneConfirmedAt != nil && u.PhoneConfirmedAt.IsZero() { + u.PhoneConfirmedAt = nil + } + if u.InvitedAt != nil && u.InvitedAt.IsZero() { + u.InvitedAt = nil + } + if u.ConfirmationSentAt != nil && u.ConfirmationSentAt.IsZero() { + u.ConfirmationSentAt = nil + } + if u.PhoneConfirmationSentAt != nil && u.PhoneConfirmationSentAt.IsZero() { + u.PhoneConfirmationSentAt = nil + } + if u.RecoverySentAt != nil && u.RecoverySentAt.IsZero() { + u.RecoverySentAt = nil + } + if u.EmailChangeSentAt != nil && u.EmailChangeSentAt.IsZero() { + u.EmailChangeSentAt = nil + } + if u.PhoneChangeSentAt != nil && u.PhoneChangeSentAt.IsZero() { + u.PhoneChangeSentAt = nil + } + if u.ReauthenticationSentAt != nil && u.ReauthenticationSentAt.IsZero() { + u.ReauthenticationSentAt = nil + } + if u.LastSignInAt != nil && u.LastSignInAt.IsZero() { + u.LastSignInAt = nil + } + if u.BannedUntil != nil && u.BannedUntil.IsZero() { + u.BannedUntil = nil + } + return nil +} + +// IsConfirmed checks if a user has already been +// registered and confirmed. +func (u *User) IsConfirmed() bool { + return u.EmailConfirmedAt != nil +} + +// HasBeenInvited checks if user has been invited +func (u *User) HasBeenInvited() bool { + return u.InvitedAt != nil +} + +// IsPhoneConfirmed checks if a user's phone has already been +// registered and confirmed. +func (u *User) IsPhoneConfirmed() bool { + return u.PhoneConfirmedAt != nil +} + +// SetRole sets the users Role to roleName +func (u *User) SetRole(tx *storage.Connection, roleName string) error { + u.Role = strings.TrimSpace(roleName) + return tx.UpdateOnly(u, "role") +} + +// HasRole returns true when the users role is set to roleName +func (u *User) HasRole(roleName string) bool { + return u.Role == roleName +} + +// GetEmail returns the user's email as a string +func (u *User) GetEmail() string { + return string(u.Email) +} + +// GetPhone returns the user's phone number as a string +func (u *User) GetPhone() string { + return string(u.Phone) +} + +// UpdateUserMetaData sets all user data from a map of updates, +// ensuring that it doesn't override attributes that are not +// in the provided map. +func (u *User) UpdateUserMetaData(tx *storage.Connection, updates map[string]interface{}) error { + if u.UserMetaData == nil { + u.UserMetaData = updates + } else { + for key, value := range updates { + if value != nil { + u.UserMetaData[key] = value + } else { + delete(u.UserMetaData, key) + } + } + } + return tx.UpdateOnly(u, "raw_user_meta_data") +} + +// UpdateAppMetaData updates all app data from a map of updates +func (u *User) UpdateAppMetaData(tx *storage.Connection, updates map[string]interface{}) error { + if u.AppMetaData == nil { + u.AppMetaData = updates + } else { + for key, value := range updates { + if value != nil { + u.AppMetaData[key] = value + } else { + delete(u.AppMetaData, key) + } + } + } + return tx.UpdateOnly(u, "raw_app_meta_data") +} + +// UpdateAppMetaDataProviders updates the provider field in AppMetaData column +func (u *User) UpdateAppMetaDataProviders(tx *storage.Connection) error { + providers, terr := FindProvidersByUser(tx, u) + if terr != nil { + return terr + } + payload := map[string]interface{}{ + "providers": providers, + } + if len(providers) > 0 { + payload["provider"] = providers[0] + } + return u.UpdateAppMetaData(tx, payload) +} + +// UpdateUserEmail updates the user's email to one of the identity's email +// if the current email used doesn't match any of the identities email +func (u *User) UpdateUserEmailFromIdentities(tx *storage.Connection) error { + identities, terr := FindIdentitiesByUserID(tx, u.ID) + if terr != nil { + return terr + } + for _, i := range identities { + if u.GetEmail() == i.GetEmail() { + // there's an existing identity that uses the same email + // so the user's email can be kept + return nil + } + } + + var primaryIdentity *Identity + for _, i := range identities { + if _, terr := FindUserByEmailAndAudience(tx, i.GetEmail(), u.Aud); terr != nil { + if IsNotFoundError(terr) { + // the identity's email is not used by another user + // so we can set it as the primary identity + primaryIdentity = i + break + } + return terr + } + } + if primaryIdentity == nil { + return UserEmailUniqueConflictError{} + } + // default to the first identity's email + if terr := u.SetEmail(tx, primaryIdentity.GetEmail()); terr != nil { + return terr + } + if primaryIdentity.GetEmail() == "" { + u.EmailConfirmedAt = nil + if terr := tx.UpdateOnly(u, "email_confirmed_at"); terr != nil { + return terr + } + } + return nil +} + +// SetEmail sets the user's email +func (u *User) SetEmail(tx *storage.Connection, email string) error { + u.Email = storage.NullString(email) + return tx.UpdateOnly(u, "email") +} + +// SetPhone sets the user's phone +func (u *User) SetPhone(tx *storage.Connection, phone string) error { + u.Phone = storage.NullString(phone) + return tx.UpdateOnly(u, "phone") +} + +func (u *User) SetPassword(ctx context.Context, password string, encrypt bool, encryptionKeyID, encryptionKey string) error { + if password == "" { + u.EncryptedPassword = nil + return nil + } + + pw, err := crypto.GenerateFromPassword(ctx, password) + if err != nil { + return err + } + + u.EncryptedPassword = &pw + if encrypt { + es, err := crypto.NewEncryptedString(u.ID.String(), []byte(pw), encryptionKeyID, encryptionKey) + if err != nil { + return err + } + + encryptedPassword := es.String() + u.EncryptedPassword = &encryptedPassword + } + + return nil +} + +// UpdatePassword updates the user's password. Use SetPassword outside of a transaction first! +func (u *User) UpdatePassword(tx *storage.Connection, sessionID *uuid.UUID) error { + // These need to be reset because password change may mean the user no longer trusts the actions performed by the previous password. + u.ConfirmationToken = "" + u.ConfirmationSentAt = nil + u.RecoveryToken = "" + u.RecoverySentAt = nil + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + u.EmailChangeSentAt = nil + u.PhoneChangeToken = "" + u.PhoneChangeSentAt = nil + u.ReauthenticationToken = "" + u.ReauthenticationSentAt = nil + + if err := tx.UpdateOnly(u, "encrypted_password", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_current", "email_change_token_new", "email_change_sent_at", "phone_change_token", "phone_change_sent_at", "reauthentication_token", "reauthentication_sent_at"); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + if sessionID == nil { + // log out user from all sessions to ensure reauthentication after password change + return Logout(tx, u.ID) + } else { + // log out user from all other sessions to ensure reauthentication after password change + return LogoutAllExceptMe(tx, *sessionID, u.ID) + } +} + +// Authenticate a user from a password +func (u *User) Authenticate(ctx context.Context, tx *storage.Connection, password string, decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (bool, bool, error) { + if u.EncryptedPassword == nil { + return false, false, nil + } + + hash := *u.EncryptedPassword + + if hash == "" { + return false, false, nil + } + + es := crypto.ParseEncryptedString(hash) + if es != nil { + h, err := es.Decrypt(u.ID.String(), decryptionKeys) + if err != nil { + return false, false, err + } + + hash = string(h) + } + + compareErr := crypto.CompareHashAndPassword(ctx, hash, password) + + if !strings.HasPrefix(hash, crypto.Argon2Prefix) && !strings.HasPrefix(hash, crypto.FirebaseScryptPrefix) { + // check if cost exceeds default cost or is too low + cost, err := bcrypt.Cost([]byte(hash)) + if err != nil { + return compareErr == nil, false, err + } + + if cost > bcrypt.DefaultCost || cost == bcrypt.MinCost { + // don't bother with encrypting the password in Authenticate + // since it's handled separately + if err := u.SetPassword(ctx, password, false, "", ""); err != nil { + return compareErr == nil, false, err + } + } + } + + return compareErr == nil, encrypt && (es == nil || es.ShouldReEncrypt(encryptionKeyID)), nil +} + +// ConfirmReauthentication resets the reauthentication token +func (u *User) ConfirmReauthentication(tx *storage.Connection) error { + u.ReauthenticationToken = "" + if err := tx.UpdateOnly(u, "reauthentication_token"); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + return nil +} + +// Confirm resets the confimation token and sets the confirm timestamp +func (u *User) Confirm(tx *storage.Connection) error { + u.ConfirmationToken = "" + now := time.Now() + u.EmailConfirmedAt = &now + + if err := tx.UpdateOnly(u, "confirmation_token", "email_confirmed_at"); err != nil { + return err + } + + if err := u.UpdateUserMetaData(tx, map[string]interface{}{ + "email_verified": true, + }); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + return nil +} + +// ConfirmPhone resets the confimation token and sets the confirm timestamp +func (u *User) ConfirmPhone(tx *storage.Connection) error { + u.ConfirmationToken = "" + now := time.Now() + u.PhoneConfirmedAt = &now + if err := tx.UpdateOnly(u, "confirmation_token", "phone_confirmed_at"); err != nil { + return err + } + + return ClearAllOneTimeTokensForUser(tx, u.ID) +} + +// UpdateLastSignInAt update field last_sign_in_at for user according to specified field +func (u *User) UpdateLastSignInAt(tx *storage.Connection) error { + return tx.UpdateOnly(u, "last_sign_in_at") +} + +// ConfirmEmailChange confirm the change of email for a user +func (u *User) ConfirmEmailChange(tx *storage.Connection, status int) error { + email := u.EmailChange + + u.Email = storage.NullString(email) + u.EmailChange = "" + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + u.EmailChangeConfirmStatus = status + + if err := tx.UpdateOnly( + u, + "email", + "email_change", + "email_change_token_current", + "email_change_token_new", + "email_change_confirm_status", + ); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + if !u.IsConfirmed() { + if err := u.Confirm(tx); err != nil { + return err + } + } + + identity, err := FindIdentityByIdAndProvider(tx, u.ID.String(), "email") + if err != nil { + if IsNotFoundError(err) { + // no email identity, not an error + return nil + } + return err + } + + if _, ok := identity.IdentityData["email"]; ok { + identity.IdentityData["email"] = email + if err := tx.UpdateOnly(identity, "identity_data"); err != nil { + return err + } + } + + return nil +} + +// ConfirmPhoneChange confirms the change of phone for a user +func (u *User) ConfirmPhoneChange(tx *storage.Connection) error { + now := time.Now() + phone := u.PhoneChange + + u.Phone = storage.NullString(phone) + u.PhoneChange = "" + u.PhoneChangeToken = "" + u.PhoneConfirmedAt = &now + + if err := tx.UpdateOnly( + u, + "phone", + "phone_change", + "phone_change_token", + "phone_confirmed_at", + ); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + identity, err := FindIdentityByIdAndProvider(tx, u.ID.String(), "phone") + if err != nil { + if IsNotFoundError(err) { + // no phone identity, not an error + return nil + } + + return err + } + + if _, ok := identity.IdentityData["phone"]; ok { + identity.IdentityData["phone"] = phone + } + + if err := tx.UpdateOnly(identity, "identity_data"); err != nil { + return err + } + + return nil +} + +// Recover resets the recovery token +func (u *User) Recover(tx *storage.Connection) error { + u.RecoveryToken = "" + if err := tx.UpdateOnly(u, "recovery_token"); err != nil { + return err + } + + return ClearAllOneTimeTokensForUser(tx, u.ID) +} + +// CountOtherUsers counts how many other users exist besides the one provided +func CountOtherUsers(tx *storage.Connection, id uuid.UUID) (int, error) { + userCount, err := tx.Q().Where("instance_id = ? and id != ?", uuid.Nil, id).Count(&User{}) + return userCount, errors.Wrap(err, "error finding registered users") +} + +func findUser(tx *storage.Connection, query string, args ...interface{}) (*User, error) { + obj := &User{} + if err := tx.Eager().Q().Where(query, args...).First(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, UserNotFoundError{} + } + return nil, errors.Wrap(err, "error finding user") + } + + return obj, nil +} + +// FindUserByEmailAndAudience finds a user with the matching email and audience. +func FindUserByEmailAndAudience(tx *storage.Connection, email, aud string) (*User, error) { + return findUser(tx, "instance_id = ? and LOWER(email) = ? and aud = ? and is_sso_user = false", uuid.Nil, strings.ToLower(email), aud) +} + +// FindUserByPhoneAndAudience finds a user with the matching email and audience. +func FindUserByPhoneAndAudience(tx *storage.Connection, phone, aud string) (*User, error) { + return findUser(tx, "instance_id = ? and phone = ? and aud = ? and is_sso_user = false", uuid.Nil, phone, aud) +} + +// FindUserByID finds a user matching the provided ID. +func FindUserByID(tx *storage.Connection, id uuid.UUID) (*User, error) { + return findUser(tx, "instance_id = ? and id = ?", uuid.Nil, id) +} + +// FindUserWithRefreshToken finds a user from the provided refresh token. If +// forUpdate is set to true, then the SELECT statement used by the query has +// the form SELECT ... FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE +// lock will only be acquired if there's no other lock. In case there is a +// lock, a IsNotFound(err) error will be returned. +func FindUserWithRefreshToken(tx *storage.Connection, token string, forUpdate bool) (*User, *RefreshToken, *Session, error) { + refreshToken := &RefreshToken{} + + if forUpdate { + // pop does not provide us with a way to execute FOR UPDATE + // queries which lock the rows affected by the query from + // being accessed by any other transaction that also uses FOR + // UPDATE + if err := tx.RawQuery(fmt.Sprintf("SELECT * FROM %q WHERE token = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", refreshToken.TableName()), token).First(refreshToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil, nil, RefreshTokenNotFoundError{} + } + + return nil, nil, nil, errors.Wrap(err, "error finding refresh token for update") + } + } + + // once the rows are locked (if forUpdate was true), we can query again using pop + if err := tx.Where("token = ?", token).First(refreshToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil, nil, RefreshTokenNotFoundError{} + } + return nil, nil, nil, errors.Wrap(err, "error finding refresh token") + } + + user, err := FindUserByID(tx, refreshToken.UserID) + if err != nil { + return nil, nil, nil, err + } + + var session *Session + + if refreshToken.SessionId != nil { + sessionId := *refreshToken.SessionId + + if sessionId != uuid.Nil { + session, err = FindSessionByID(tx, sessionId, forUpdate) + if err != nil { + if forUpdate { + return nil, nil, nil, err + } + + if !IsNotFoundError(err) { + return nil, nil, nil, errors.Wrap(err, "error finding session from refresh token") + } + + // otherwise, there's no session for this refresh token + } + } + } + + return user, refreshToken, session, nil +} + +// FindUsersInAudience finds users with the matching audience. +func FindUsersInAudience(tx *storage.Connection, aud string, pageParams *Pagination, sortParams *SortParams, filter string) ([]*User, error) { + users := []*User{} + q := tx.Q().Where("instance_id = ? and aud = ?", uuid.Nil, aud) + + if filter != "" { + lf := "%" + filter + "%" + // we must specify the collation in order to get case insensitive search for the JSON column + q = q.Where("(email LIKE ? OR raw_user_meta_data->>'full_name' ILIKE ?)", lf, lf) + } + + if sortParams != nil && len(sortParams.Fields) > 0 { + for _, field := range sortParams.Fields { + q = q.Order(field.Name + " " + string(field.Dir)) + } + } + + var err error + if pageParams != nil { + err = q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&users) // #nosec G115 + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + } else { + err = q.All(&users) + } + + return users, err +} + +// IsDuplicatedEmail returns whether a user exists with a matching email and audience. +// If a currentUser is provided, we will need to filter out any identities that belong to the current user. +func IsDuplicatedEmail(tx *storage.Connection, email, aud string, currentUser *User) (*User, error) { + var identities []Identity + + if err := tx.Eager().Q().Where("email = ?", strings.ToLower(email)).All(&identities); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil + } + + return nil, errors.Wrap(err, "unable to find identity by email for duplicates") + } + + userIDs := make(map[string]uuid.UUID) + for _, identity := range identities { + if _, ok := userIDs[identity.UserID.String()]; !ok { + if !identity.IsForSSOProvider() { + userIDs[identity.UserID.String()] = identity.UserID + } + } + } + + var currentUserId uuid.UUID + if currentUser != nil { + currentUserId = currentUser.ID + } + + for _, userID := range userIDs { + if userID != currentUserId { + user, err := FindUserByID(tx, userID) + if err != nil { + return nil, errors.Wrap(err, "unable to find user from email identity for duplicates") + } + if user.Aud == aud { + return user, nil + } + } + } + + // out of an abundance of caution, if nothing was found via the + // identities table we also do a final check on the users table + user, err := FindUserByEmailAndAudience(tx, email, aud) + if err != nil && !IsNotFoundError(err) { + return nil, errors.Wrap(err, "unable to find user email address for duplicates") + } + + return user, nil +} + +// IsDuplicatedPhone checks if the phone number already exists in the users table +func IsDuplicatedPhone(tx *storage.Connection, phone, aud string) (bool, error) { + _, err := FindUserByPhoneAndAudience(tx, phone, aud) + if err != nil { + if IsNotFoundError(err) { + return false, nil + } + return false, err + } + return true, nil +} + +// Ban a user for a given duration. +func (u *User) Ban(tx *storage.Connection, duration time.Duration) error { + if duration == time.Duration(0) { + u.BannedUntil = nil + } else { + t := time.Now().Add(duration) + u.BannedUntil = &t + } + return tx.UpdateOnly(u, "banned_until") +} + +// IsBanned checks if a user is banned or not +func (u *User) IsBanned() bool { + if u.BannedUntil == nil { + return false + } + return time.Now().Before(*u.BannedUntil) +} + +func (u *User) HasMFAEnabled() bool { + for _, factor := range u.Factors { + if factor.IsVerified() { + return true + } + } + + return false +} + +func (u *User) UpdateBannedUntil(tx *storage.Connection) error { + return tx.UpdateOnly(u, "banned_until") +} + +// RemoveUnconfirmedIdentities removes potentially malicious unconfirmed identities from a user (if any) +func (u *User) RemoveUnconfirmedIdentities(tx *storage.Connection, identity *Identity) error { + if identity.Provider != "email" && identity.Provider != "phone" { + // user is unconfirmed so the password should be reset + u.EncryptedPassword = nil + if terr := tx.UpdateOnly(u, "encrypted_password"); terr != nil { + return terr + } + } + + // user is unconfirmed so existing user_metadata should be overwritten + // to use the current identity metadata + u.UserMetaData = identity.IdentityData + if terr := u.UpdateUserMetaData(tx, u.UserMetaData); terr != nil { + return terr + } + + // finally, remove all identities except the current identity being authenticated + for i := range u.Identities { + if u.Identities[i].ID != identity.ID { + if terr := tx.Destroy(&u.Identities[i]); terr != nil { + return terr + } + } + } + + // user is unconfirmed so none of the providers associated to it are verified yet + // only the current provider should be kept + if terr := u.UpdateAppMetaDataProviders(tx); terr != nil { + return terr + } + return nil +} + +// SoftDeleteUser performs a soft deletion on the user by obfuscating and clearing certain fields +func (u *User) SoftDeleteUser(tx *storage.Connection) error { + u.Email = storage.NullString(obfuscateEmail(u, u.GetEmail())) + u.Phone = storage.NullString(obfuscatePhone(u, u.GetPhone())) + u.EmailChange = obfuscateEmail(u, u.EmailChange) + u.PhoneChange = obfuscatePhone(u, u.PhoneChange) + u.EncryptedPassword = nil + u.ConfirmationToken = "" + u.RecoveryToken = "" + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + u.PhoneChangeToken = "" + + // set deleted_at time + now := time.Now() + u.DeletedAt = &now + + if err := tx.UpdateOnly( + u, + "email", + "phone", + "encrypted_password", + "email_change", + "phone_change", + "confirmation_token", + "recovery_token", + "email_change_token_current", + "email_change_token_new", + "phone_change_token", + "deleted_at", + ); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + // set raw_user_meta_data to {} + userMetaDataUpdates := map[string]interface{}{} + for k := range u.UserMetaData { + userMetaDataUpdates[k] = nil + } + + if err := u.UpdateUserMetaData(tx, userMetaDataUpdates); err != nil { + return err + } + + // set raw_app_meta_data to {} + appMetaDataUpdates := map[string]interface{}{} + for k := range u.AppMetaData { + appMetaDataUpdates[k] = nil + } + + if err := u.UpdateAppMetaData(tx, appMetaDataUpdates); err != nil { + return err + } + + if err := Logout(tx, u.ID); err != nil { + return err + } + + return nil +} + +// SoftDeleteUserIdentities performs a soft deletion on all identities associated to a user +func (u *User) SoftDeleteUserIdentities(tx *storage.Connection) error { + identities, err := FindIdentitiesByUserID(tx, u.ID) + if err != nil { + return err + } + + // set identity_data to {} + for _, identity := range identities { + identityDataUpdates := map[string]interface{}{} + for k := range identity.IdentityData { + identityDataUpdates[k] = nil + } + if err := identity.UpdateIdentityData(tx, identityDataUpdates); err != nil { + return err + } + // updating the identity.ID has to happen last since the primary key is on (provider, id) + // we use RawQuery here instead of UpdateOnly because UpdateOnly relies on the primary key of Identity + if err := tx.RawQuery( + "update "+ + (&pop.Model{Value: Identity{}}).TableName()+ + " set provider_id = ? where id = ?", + obfuscateIdentityProviderId(identity), + identity.ID, + ).Exec(); err != nil { + return err + } + } + return nil +} + +func (u *User) FindOwnedFactorByID(tx *storage.Connection, factorID uuid.UUID) (*Factor, error) { + var factor Factor + err := tx.Where("user_id = ? AND id = ?", u.ID, factorID).First(&factor) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, &FactorNotFoundError{} + } + return nil, err + } + return &factor, nil +} + +func (user *User) WebAuthnID() []byte { + return []byte(user.ID.String()) +} + +func (user *User) WebAuthnName() string { + return user.Email.String() +} + +func (user *User) WebAuthnDisplayName() string { + return user.Email.String() +} + +func (user *User) WebAuthnCredentials() []webauthn.Credential { + var credentials []webauthn.Credential + + for _, factor := range user.Factors { + if factor.IsVerified() && factor.FactorType == WebAuthn { + credential := factor.WebAuthnCredential.Credential + credentials = append(credentials, credential) + } + } + + return credentials +} + +func obfuscateValue(id uuid.UUID, value string) string { + hash := sha256.Sum256([]byte(id.String() + value)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +func obfuscateEmail(u *User, email string) string { + return obfuscateValue(u.ID, email) +} + +func obfuscatePhone(u *User, phone string) string { + // Field converted from VARCHAR(15) to text + return obfuscateValue(u.ID, phone)[:15] +} + +func obfuscateIdentityProviderId(identity *Identity) string { + return obfuscateValue(identity.UserID, identity.Provider+":"+identity.ProviderID) +} + +// FindUserByPhoneChangeAndAudience finds a user with the matching phone change and audience. +func FindUserByPhoneChangeAndAudience(tx *storage.Connection, phone, aud string) (*User, error) { + return findUser(tx, "instance_id = ? and phone_change = ? and aud = ? and is_sso_user = false", uuid.Nil, phone, aud) +} diff --git a/internal/models/user_test.go b/internal/models/user_test.go new file mode 100644 index 000000000..034954389 --- /dev/null +++ b/internal/models/user_test.go @@ -0,0 +1,467 @@ +package models + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" + "golang.org/x/crypto/bcrypt" +) + +const modelsTestConfig = "../../hack/test.env" + +func init() { + crypto.PasswordHashCost = crypto.QuickHashCost +} + +type UserTestSuite struct { + suite.Suite + db *storage.Connection +} + +func (ts *UserTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestUser(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &UserTestSuite{ + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *UserTestSuite) TestUpdateAppMetadata() { + u, err := NewUser("", "", "", "", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, make(map[string]interface{}))) + + require.NotNil(ts.T(), u.AppMetaData) + + require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, map[string]interface{}{ + "foo": "bar", + })) + + require.Equal(ts.T(), "bar", u.AppMetaData["foo"]) + require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, map[string]interface{}{ + "foo": nil, + })) + require.Len(ts.T(), u.AppMetaData, 0) + require.Equal(ts.T(), nil, u.AppMetaData["foo"]) +} + +func (ts *UserTestSuite) TestUpdateUserMetadata() { + u, err := NewUser("", "", "", "", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, make(map[string]interface{}))) + + require.NotNil(ts.T(), u.UserMetaData) + + require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, map[string]interface{}{ + "foo": "bar", + })) + + require.Equal(ts.T(), "bar", u.UserMetaData["foo"]) + require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, map[string]interface{}{ + "foo": nil, + })) + require.Len(ts.T(), u.UserMetaData, 0) + require.Equal(ts.T(), nil, u.UserMetaData["foo"]) +} + +func (ts *UserTestSuite) TestFindUserByConfirmationToken() { + u := ts.createUser() + tokenHash := "test_confirmation_token" + require.NoError(ts.T(), CreateOneTimeToken(ts.db, u.ID, "relates_to not used", tokenHash, ConfirmationToken)) + + n, err := FindUserByConfirmationToken(ts.db, tokenHash) + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, n.ID) +} + +func (ts *UserTestSuite) TestFindUserByEmailAndAudience() { + u := ts.createUser() + + n, err := FindUserByEmailAndAudience(ts.db, u.GetEmail(), "test") + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, n.ID) + + _, err = FindUserByEmailAndAudience(ts.db, u.GetEmail(), "invalid") + require.EqualError(ts.T(), err, UserNotFoundError{}.Error()) +} + +func (ts *UserTestSuite) TestFindUsersInAudience() { + u := ts.createUser() + + n, err := FindUsersInAudience(ts.db, u.Aud, nil, nil, "") + require.NoError(ts.T(), err) + require.Len(ts.T(), n, 1) + + p := Pagination{ + Page: 1, + PerPage: 50, + } + n, err = FindUsersInAudience(ts.db, u.Aud, &p, nil, "") + require.NoError(ts.T(), err) + require.Len(ts.T(), n, 1) + assert.Equal(ts.T(), uint64(1), p.Count) + + sp := &SortParams{ + Fields: []SortField{ + {Name: "created_at", Dir: Descending}, + }, + } + n, err = FindUsersInAudience(ts.db, u.Aud, nil, sp, "") + require.NoError(ts.T(), err) + require.Len(ts.T(), n, 1) +} + +func (ts *UserTestSuite) TestFindUserByID() { + u := ts.createUser() + + n, err := FindUserByID(ts.db, u.ID) + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, n.ID) +} + +func (ts *UserTestSuite) TestFindUserByRecoveryToken() { + u := ts.createUser() + tokenHash := "test_recovery_token" + require.NoError(ts.T(), CreateOneTimeToken(ts.db, u.ID, "relates_to not used", tokenHash, RecoveryToken)) + + n, err := FindUserByRecoveryToken(ts.db, tokenHash) + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, n.ID) +} + +func (ts *UserTestSuite) TestFindUserWithRefreshToken() { + u := ts.createUser() + r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) + require.NoError(ts.T(), err) + + n, nr, s, err := FindUserWithRefreshToken(ts.db, r.Token, true /* forUpdate */) + require.NoError(ts.T(), err) + require.Equal(ts.T(), r.ID, nr.ID) + require.Equal(ts.T(), u.ID, n.ID) + require.NotNil(ts.T(), s) + require.Equal(ts.T(), *r.SessionId, s.ID) +} + +func (ts *UserTestSuite) TestIsDuplicatedEmail() { + _ = ts.createUserWithEmail("david.calavera@netlify.com") + + e, err := IsDuplicatedEmail(ts.db, "david.calavera@netlify.com", "test", nil) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), e, "expected email to be duplicated") + + e, err = IsDuplicatedEmail(ts.db, "davidcalavera@netlify.com", "test", nil) + require.NoError(ts.T(), err) + require.Nil(ts.T(), e, "expected email to not be duplicated", nil) + + e, err = IsDuplicatedEmail(ts.db, "david@netlify.com", "test", nil) + require.NoError(ts.T(), err) + require.Nil(ts.T(), e, "expected same email to not be duplicated", nil) + + e, err = IsDuplicatedEmail(ts.db, "david.calavera@netlify.com", "other-aud", nil) + require.NoError(ts.T(), err) + require.Nil(ts.T(), e, "expected same email to not be duplicated") +} + +func (ts *UserTestSuite) createUser() *User { + return ts.createUserWithEmail("david@netlify.com") +} + +func (ts *UserTestSuite) createUserWithEmail(email string) *User { + user, err := NewUser("", email, "secret", "test", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) + + identity, err := NewIdentity(user, "email", map[string]interface{}{ + "sub": user.ID.String(), + "email": email, + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identity)) + + return user +} + +func (ts *UserTestSuite) TestRemoveUnconfirmedIdentities() { + user, err := NewUser("+29382983298", "someone@example.com", "abcdefgh", "authenticated", nil) + require.NoError(ts.T(), err) + + user.AppMetaData = map[string]interface{}{ + "provider": "email", + "providers": []string{"email", "phone", "twitter"}, + } + + require.NoError(ts.T(), ts.db.Create(user)) + + idEmail, err := NewIdentity(user, "email", map[string]interface{}{ + "sub": "someone@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(idEmail)) + + idPhone, err := NewIdentity(user, "phone", map[string]interface{}{ + "sub": "+29382983298", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(idPhone)) + + idTwitter, err := NewIdentity(user, "twitter", map[string]interface{}{ + "sub": "test_twitter_user_id", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(idTwitter)) + + user.Identities = append(user.Identities, *idEmail, *idPhone, *idTwitter) + + // reload the user + require.NoError(ts.T(), ts.db.Load(user)) + + require.False(ts.T(), user.IsConfirmed(), "user's email must not be confirmed") + + require.NoError(ts.T(), user.RemoveUnconfirmedIdentities(ts.db, idTwitter)) + + // reload the user to check that identities are deleted from the db too + require.NoError(ts.T(), ts.db.Load(user)) + require.Empty(ts.T(), user.EncryptedPassword, "password still remains in user") + + require.Len(ts.T(), user.Identities, 1, "only one identity must be remaining") + require.Equal(ts.T(), idTwitter.ID, user.Identities[0].ID, "remaining identity is not the expected one") + + require.NotNil(ts.T(), user.AppMetaData) + require.Equal(ts.T(), user.AppMetaData["provider"], "twitter") + require.Equal(ts.T(), user.AppMetaData["providers"], []string{"twitter"}) +} + +func (ts *UserTestSuite) TestConfirmEmailChange() { + user, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) + + identity, err := NewIdentity(user, "email", map[string]interface{}{ + "sub": user.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identity)) + + user.EmailChange = "new@example.com" + require.NoError(ts.T(), ts.db.UpdateOnly(user, "email_change")) + + require.NoError(ts.T(), user.ConfirmEmailChange(ts.db, 0)) + + require.NoError(ts.T(), ts.db.Eager().Load(user)) + identity, err = FindIdentityByIdAndProvider(ts.db, user.ID.String(), "email") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), user.Email, storage.NullString("new@example.com")) + require.Equal(ts.T(), user.EmailChange, "") + + require.NotNil(ts.T(), identity.IdentityData) + require.Equal(ts.T(), identity.IdentityData["email"], "new@example.com") +} + +func (ts *UserTestSuite) TestConfirmPhoneChange() { + user, err := NewUser("123456789", "", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) + + identity, err := NewIdentity(user, "phone", map[string]interface{}{ + "sub": user.ID.String(), + "phone": "123456789", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identity)) + + user.PhoneChange = "987654321" + require.NoError(ts.T(), ts.db.UpdateOnly(user, "phone_change")) + + require.NoError(ts.T(), user.ConfirmPhoneChange(ts.db)) + + require.NoError(ts.T(), ts.db.Eager().Load(user)) + identity, err = FindIdentityByIdAndProvider(ts.db, user.ID.String(), "phone") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), user.Phone, storage.NullString("987654321")) + require.Equal(ts.T(), user.PhoneChange, "") + + require.NotNil(ts.T(), identity.IdentityData) + require.Equal(ts.T(), identity.IdentityData["phone"], "987654321") +} + +func (ts *UserTestSuite) TestUpdateUserEmailSuccess() { + userA, err := NewUser("", "foo@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + + primaryIdentity, err := NewIdentity(userA, "email", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "foo@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(primaryIdentity)) + + secondaryIdentity, err := NewIdentity(userA, "google", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "bar@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(secondaryIdentity)) + + // UpdateUserEmail should not do anything and the user's email should still use the primaryIdentity + require.NoError(ts.T(), userA.UpdateUserEmailFromIdentities(ts.db)) + require.Equal(ts.T(), primaryIdentity.GetEmail(), userA.GetEmail()) + + // remove primary identity + require.NoError(ts.T(), ts.db.Destroy(primaryIdentity)) + + // UpdateUserEmail should update the user to use the secondary identity's email + require.NoError(ts.T(), userA.UpdateUserEmailFromIdentities(ts.db)) + require.Equal(ts.T(), secondaryIdentity.GetEmail(), userA.GetEmail()) +} + +func (ts *UserTestSuite) TestUpdateUserEmailFailure() { + userA, err := NewUser("", "foo@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + + primaryIdentity, err := NewIdentity(userA, "email", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "foo@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(primaryIdentity)) + + secondaryIdentity, err := NewIdentity(userA, "google", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "bar@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(secondaryIdentity)) + + userB, err := NewUser("", "bar@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userB)) + + // remove primary identity + require.NoError(ts.T(), ts.db.Destroy(primaryIdentity)) + + // UpdateUserEmail should fail with the email unique constraint violation error + // since userB is using the secondary identity's email + require.ErrorIs(ts.T(), userA.UpdateUserEmailFromIdentities(ts.db), UserEmailUniqueConflictError{}) + require.Equal(ts.T(), primaryIdentity.GetEmail(), userA.GetEmail()) +} + +func (ts *UserTestSuite) TestNewUserWithPasswordHashSuccess() { + cases := []struct { + desc string + hash string + }{ + { + desc: "Valid bcrypt hash", + hash: "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq", + }, + { + desc: "Valid argon2i hash", + hash: "$argon2i$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + }, + { + desc: "Valid argon2id hash", + hash: "$argon2id$v=19$m=32,t=3,p=2$SFVpOWJ0eXhjRzVkdGN1RQ$RXnb8rh7LaDcn07xsssqqulZYXOM/EUCEFMVcAcyYVk", + }, + { + desc: "Valid Firebase scrypt hash", + hash: "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$ZGlmZmVyZW50aGFzaA==", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), u) + }) + } +} + +func (ts *UserTestSuite) TestNewUserWithPasswordHashFailure() { + cases := []struct { + desc string + hash string + }{ + { + desc: "Invalid argon2i hash", + hash: "$argon2id$test", + }, + { + desc: "Invalid bcrypt hash", + hash: "plaintest_password", + }, + { + desc: "Invalid scrypt hash", + hash: "$fbscrypt$invalid", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.Error(ts.T(), err) + require.Nil(ts.T(), u) + }) + } +} + +func (ts *UserTestSuite) TestAuthenticate() { + // every case uses "test" as the password + cases := []struct { + desc string + hash string + expectedHashCost int + }{ + { + desc: "Invalid bcrypt hash cost of 11", + hash: "$2y$11$4lH57PU7bGATpRcx93vIoObH3qDmft/pytbOzDG9/1WsyNmN5u4di", + expectedHashCost: bcrypt.MinCost, + }, + { + desc: "Valid bcrypt hash cost of 10", + hash: "$2y$10$va66S4MxFrH6G6L7BzYl0.QgcYgvSr/F92gc.3botlz7bG4p/g/1i", + expectedHashCost: bcrypt.DefaultCost, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(u)) + require.NotNil(ts.T(), u) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.db, "test", nil, false, "") + require.NoError(ts.T(), err) + require.True(ts.T(), isAuthenticated) + + // check hash cost + hashCost, err := bcrypt.Cost([]byte(*u.EncryptedPassword)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expectedHashCost, hashCost) + }) + } +} diff --git a/internal/models/web3.go b/internal/models/web3.go new file mode 100644 index 000000000..6fc274e16 --- /dev/null +++ b/internal/models/web3.go @@ -0,0 +1,4 @@ +package models + +const Web3Provider = "web3" +const Web3Grant = "web3" diff --git a/internal/observability/cleanup.go b/internal/observability/cleanup.go new file mode 100644 index 000000000..2e88c3590 --- /dev/null +++ b/internal/observability/cleanup.go @@ -0,0 +1,18 @@ +package observability + +import ( + "context" + "sync" + + "github.com/supabase/auth/internal/utilities" +) + +var ( + cleanupWaitGroup sync.WaitGroup +) + +// WaitForCleanup waits until all observability long-running goroutines shut +// down cleanly or until the provided context signals done. +func WaitForCleanup(ctx context.Context) { + utilities.WaitForCleanup(ctx, &cleanupWaitGroup) +} diff --git a/internal/observability/logging.go b/internal/observability/logging.go new file mode 100644 index 000000000..ff8ac96ea --- /dev/null +++ b/internal/observability/logging.go @@ -0,0 +1,125 @@ +package observability + +import ( + "os" + "sync" + "time" + + "github.com/bombsimon/logrusr/v3" + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/pop/v6/logging" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "go.opentelemetry.io/otel" +) + +const ( + LOG_SQL_ALL = "all" + LOG_SQL_NONE = "none" + LOG_SQL_STATEMENT = "statement" +) + +var ( + loggingOnce sync.Once +) + +type CustomFormatter struct { + logrus.JSONFormatter +} + +func NewCustomFormatter() *CustomFormatter { + return &CustomFormatter{ + JSONFormatter: logrus.JSONFormatter{ + DisableTimestamp: false, + TimestampFormat: time.RFC3339, + }, + } +} + +func (f *CustomFormatter) Format(entry *logrus.Entry) ([]byte, error) { + // logrus doesn't support formatting the time in UTC so we need to use a custom formatter + entry.Time = entry.Time.UTC() + return f.JSONFormatter.Format(entry) +} + +func ConfigureLogging(config *conf.LoggingConfig) error { + var err error + + loggingOnce.Do(func() { + formatter := NewCustomFormatter() + logrus.SetFormatter(formatter) + + // use a file if you want + if config.File != "" { + f, errOpen := os.OpenFile(config.File, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0660) //#nosec G302 -- Log files should be rw-rw-r-- + if errOpen != nil { + err = errOpen + return + } + logrus.SetOutput(f) + logrus.Infof("Set output file to %s", config.File) + } + + if config.Level != "" { + level, errParse := logrus.ParseLevel(config.Level) + if err != nil { + err = errParse + return + } + logrus.SetLevel(level) + logrus.Debug("Set log level to: " + logrus.GetLevel().String()) + } + + f := logrus.Fields{} + for k, v := range config.Fields { + f[k] = v + } + logrus.WithFields(f) + + setPopLogger(config.SQL) + + otel.SetLogger(logrusr.New(logrus.StandardLogger().WithField("component", "otel"))) + }) + + return err +} + +func setPopLogger(sql string) { + popLog := logrus.WithField("component", "pop") + sqlLog := logrus.WithField("component", "sql") + + shouldLogSQL := sql == LOG_SQL_STATEMENT || sql == LOG_SQL_ALL + shouldLogSQLArgs := sql == LOG_SQL_ALL + + pop.SetLogger(func(lvl logging.Level, s string, args ...interface{}) { + // Special case SQL logging since we have 2 extra flags to check + if lvl == logging.SQL { + if !shouldLogSQL { + return + } + + if shouldLogSQLArgs && len(args) > 0 { + sqlLog.WithField("args", args).Info(s) + } else { + sqlLog.Info(s) + } + return + } + + l := popLog + if len(args) > 0 { + l = l.WithField("args", args) + } + + switch lvl { + case logging.SQL, logging.Debug: + l.Debug(s) + case logging.Info: + l.Info(s) + case logging.Warn: + l.Warn(s) + case logging.Error: + l.Error(s) + } + }) +} diff --git a/internal/observability/metrics.go b/internal/observability/metrics.go new file mode 100644 index 000000000..b3632aa8e --- /dev/null +++ b/internal/observability/metrics.go @@ -0,0 +1,202 @@ +package observability + +import ( + "context" + "fmt" + "net" + "net/http" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" + "go.opentelemetry.io/otel/exporters/prometheus" + "go.opentelemetry.io/otel/metric" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + + otelruntimemetrics "go.opentelemetry.io/contrib/instrumentation/runtime" +) + +func Meter(instrumentationName string, opts ...metric.MeterOption) metric.Meter { + return otel.Meter(instrumentationName, opts...) +} + +func ObtainMetricCounter(name, desc string) metric.Int64Counter { + counter, err := Meter("gotrue").Int64Counter(name, metric.WithDescription(desc)) + if err != nil { + panic(err) + } + return counter +} + +func enablePrometheusMetrics(ctx context.Context, mc *conf.MetricsConfig) error { + exporter, err := prometheus.New() + if err != nil { + return err + } + + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(exporter)) + + otel.SetMeterProvider(provider) + + cleanupWaitGroup.Add(1) + go func() { + addr := net.JoinHostPort(mc.PrometheusListenHost, mc.PrometheusListenPort) + baseContext, cancel := context.WithCancel(context.Background()) + + server := &http.Server{ + Addr: addr, + Handler: promhttp.Handler(), + BaseContext: func(net.Listener) context.Context { + return baseContext + }, + ReadHeaderTimeout: 2 * time.Second, // to mitigate a Slowloris attack + } + + go func() { + defer cleanupWaitGroup.Done() + <-ctx.Done() + + cancel() // close baseContext + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Errorf("prometheus server (%s) failed to gracefully shut down", addr) + } + }() + + logrus.Infof("prometheus server listening on %s", addr) + + if err := server.ListenAndServe(); err != nil { + logrus.WithError(err).Errorf("prometheus server (%s) shut down", addr) + } else { + logrus.Info("prometheus metric exporter shut down") + } + }() + + return nil +} + +func enableOpenTelemetryMetrics(ctx context.Context, mc *conf.MetricsConfig) error { + switch mc.ExporterProtocol { + case "grpc": + metricExporter, err := otlpmetricgrpc.New(ctx) + if err != nil { + return err + } + meterProvider := sdkmetric.NewMeterProvider( + sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExporter)), + ) + + otel.SetMeterProvider(meterProvider) + + cleanupWaitGroup.Add(1) + go func() { + defer cleanupWaitGroup.Done() + + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := metricExporter.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Error("unable to gracefully shut down OpenTelemetry metric exporter") + } else { + logrus.Info("OpenTelemetry metric exporter shut down") + } + }() + + case "http/protobuf": + metricExporter, err := otlpmetrichttp.New(ctx) + if err != nil { + return err + } + meterProvider := sdkmetric.NewMeterProvider( + sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExporter)), + ) + + otel.SetMeterProvider(meterProvider) + + cleanupWaitGroup.Add(1) + go func() { + defer cleanupWaitGroup.Done() + + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := metricExporter.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Error("unable to gracefully shut down OpenTelemetry metric exporter") + } else { + logrus.Info("OpenTelemetry metric exporter shut down") + } + }() + + default: // http/json for example + return fmt.Errorf("unsupported OpenTelemetry exporter protocol %q", mc.ExporterProtocol) + } + logrus.Info("OpenTelemetry metrics exporter started") + return nil + +} + +var ( + metricsOnce *sync.Once = &sync.Once{} +) + +func ConfigureMetrics(ctx context.Context, mc *conf.MetricsConfig) error { + if ctx == nil { + panic("context must not be nil") + } + + var err error + + metricsOnce.Do(func() { + if mc.Enabled { + switch mc.Exporter { + case conf.Prometheus: + if err = enablePrometheusMetrics(ctx, mc); err != nil { + logrus.WithError(err).Error("unable to start prometheus metrics exporter") + return + } + + case conf.OpenTelemetryMetrics: + if err = enableOpenTelemetryMetrics(ctx, mc); err != nil { + logrus.WithError(err).Error("unable to start OTLP metrics exporter") + + return + } + } + } + + if err := otelruntimemetrics.Start(otelruntimemetrics.WithMinimumReadMemStatsInterval(time.Second)); err != nil { + logrus.WithError(err).Error("unable to start OpenTelemetry Go runtime metrics collection") + } else { + logrus.Info("Go runtime metrics collection started") + } + + meter := otel.Meter("gotrue") + _, err := meter.Int64ObservableGauge( + "gotrue_running", + metric.WithDescription("Whether GoTrue is running (always 1)"), + metric.WithInt64Callback(func(_ context.Context, obsrv metric.Int64Observer) error { + obsrv.Observe(int64(1)) + return nil + }), + ) + if err != nil { + logrus.WithError(err).Error("unable to get gotrue.gotrue_running gague metric") + return + } + }) + + return err +} diff --git a/internal/observability/profiler.go b/internal/observability/profiler.go new file mode 100644 index 000000000..71acc1183 --- /dev/null +++ b/internal/observability/profiler.go @@ -0,0 +1,87 @@ +package observability + +import ( + "context" + "net" + "time" + + "net/http" + "net/http/pprof" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" +) + +func ConfigureProfiler(ctx context.Context, pc *conf.ProfilerConfig) error { + if !pc.Enabled { + return nil + } + addr := net.JoinHostPort(pc.Host, pc.Port) + baseContext, cancel := context.WithCancel(context.Background()) + cleanupWaitGroup.Add(1) + go func() { + server := &http.Server{ + Addr: addr, + Handler: &ProfilerHandler{}, + BaseContext: func(net.Listener) context.Context { + return baseContext + }, + ReadHeaderTimeout: 2 * time.Second, + } + + go func() { + defer cleanupWaitGroup.Done() + <-ctx.Done() + + cancel() // close baseContext + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Errorf("profiler server (%s) failed to gracefully shut down", addr) + } + }() + + logrus.Infof("Profiler is listening on %s", addr) + + if err := server.ListenAndServe(); err != nil { + logrus.WithError(err).Errorf("profiler server (%s) shut down", addr) + } else { + logrus.Info("profiler shut down") + } + }() + + return nil +} + +type ProfilerHandler struct{} + +func (p *ProfilerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/debug/pprof/": + pprof.Index(w, r) + case "/debug/pprof/cmdline": + pprof.Cmdline(w, r) + case "/debug/pprof/profile": + pprof.Profile(w, r) + case "/debug/pprof/symbol": + pprof.Symbol(w, r) + case "/debug/pprof/trace": + pprof.Trace(w, r) + case "/debug/pprof/goroutine": + pprof.Handler("goroutine").ServeHTTP(w, r) + case "/debug/pprof/heap": + pprof.Handler("heap").ServeHTTP(w, r) + case "/debug/pprof/allocs": + pprof.Handler("allocs").ServeHTTP(w, r) + case "/debug/pprof/threadcreate": + pprof.Handler("threadcreate").ServeHTTP(w, r) + case "/debug/pprof/block": + pprof.Handler("block").ServeHTTP(w, r) + case "/debug/pprof/mutex": + pprof.Handler("mutex").ServeHTTP(w, r) + default: + http.NotFound(w, r) + } +} diff --git a/internal/observability/request-logger.go b/internal/observability/request-logger.go new file mode 100644 index 000000000..6eeffd6ea --- /dev/null +++ b/internal/observability/request-logger.go @@ -0,0 +1,114 @@ +package observability + +import ( + "fmt" + "net/http" + "time" + + chimiddleware "github.com/go-chi/chi/v5/middleware" + "github.com/gofrs/uuid" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +func AddRequestID(globalConfig *conf.GlobalConfiguration) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + id := uuid.Must(uuid.NewV4()).String() + if globalConfig.API.RequestIDHeader != "" { + id = r.Header.Get(globalConfig.API.RequestIDHeader) + } + ctx := r.Context() + ctx = utilities.WithRequestID(ctx, id) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} + +func NewStructuredLogger(logger *logrus.Logger, config *conf.GlobalConfiguration) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/health" { + next.ServeHTTP(w, r) + } else { + chimiddleware.RequestLogger(&structuredLogger{logger, config})(next).ServeHTTP(w, r) + } + }) + } +} + +type structuredLogger struct { + Logger *logrus.Logger + Config *conf.GlobalConfiguration +} + +func (l *structuredLogger) NewLogEntry(r *http.Request) chimiddleware.LogEntry { + referrer := utilities.GetReferrer(r, l.Config) + e := &logEntry{Entry: logrus.NewEntry(l.Logger)} + logFields := logrus.Fields{ + "component": "api", + "method": r.Method, + "path": r.URL.Path, + "remote_addr": utilities.GetIPAddress(r), + "referer": referrer, + } + + if reqID := utilities.GetRequestID(r.Context()); reqID != "" { + logFields["request_id"] = reqID + } + + e.Entry = e.Entry.WithFields(logFields) + return e +} + +// logEntry implements the chiMiddleware.LogEntry interface +type logEntry struct { + Entry *logrus.Entry +} + +func (e *logEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) { + fields := logrus.Fields{ + "status": status, + "duration": elapsed.Nanoseconds(), + } + + errorCode := header.Get("x-sb-error-code") + if errorCode != "" { + fields["error_code"] = errorCode + } + + entry := e.Entry.WithFields(fields) + entry.Info("request completed") + e.Entry = entry +} + +func (e *logEntry) Panic(v interface{}, stack []byte) { + entry := e.Entry.WithFields(logrus.Fields{ + "stack": string(stack), + "panic": fmt.Sprintf("%+v", v), + }) + entry.Error("request panicked") + e.Entry = entry +} + +func GetLogEntry(r *http.Request) *logEntry { + l, _ := chimiddleware.GetLogEntry(r).(*logEntry) + if l == nil { + return &logEntry{Entry: logrus.NewEntry(logrus.StandardLogger())} + } + return l +} + +func LogEntrySetField(r *http.Request, key string, value interface{}) { + if l, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*logEntry); ok { + l.Entry = l.Entry.WithField(key, value) + } +} + +func LogEntrySetFields(r *http.Request, fields logrus.Fields) { + if l, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*logEntry); ok { + l.Entry = l.Entry.WithFields(fields) + } +} diff --git a/internal/observability/request-logger_test.go b/internal/observability/request-logger_test.go new file mode 100644 index 000000000..7ab244c3f --- /dev/null +++ b/internal/observability/request-logger_test.go @@ -0,0 +1,72 @@ +package observability + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +const apiTestConfig = "../../hack/test.env" + +func TestLogger(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + + config.Logging.Level = "info" + require.NoError(t, ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs are output correctly + logrus.SetOutput(&logBuffer) + + // add request id header + config.API.RequestIDHeader = "X-Request-ID" + addRequestIdHandler := AddRequestID(config) + + logHandler := NewStructuredLogger(logrus.StandardLogger(), config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com/path", nil) + req.Header.Add("X-Request-ID", "test-request-id") + require.NoError(t, err) + addRequestIdHandler(logHandler).ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var logs map[string]interface{} + require.NoError(t, json.NewDecoder(&logBuffer).Decode(&logs)) + require.Equal(t, "api", logs["component"]) + require.Equal(t, http.MethodPost, logs["method"]) + require.Equal(t, "/path", logs["path"]) + require.Equal(t, "test-request-id", logs["request_id"]) + require.NotNil(t, logs["time"]) +} + +func TestExcludeHealthFromLogs(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + + config.Logging.Level = "info" + require.NoError(t, ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs are output correctly + logrus.SetOutput(&logBuffer) + + logHandler := NewStructuredLogger(logrus.StandardLogger(), config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "http://example.com/health", nil) + require.NoError(t, err) + logHandler.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + require.Empty(t, logBuffer) +} diff --git a/internal/observability/request-tracing.go b/internal/observability/request-tracing.go new file mode 100644 index 000000000..e8ee61bc1 --- /dev/null +++ b/internal/observability/request-tracing.go @@ -0,0 +1,170 @@ +package observability + +import ( + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/sirupsen/logrus" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.25.0" + "go.opentelemetry.io/otel/trace" +) + +// traceChiRoutesSafely attempts to extract the Chi RouteContext. If the +// request does not have a RouteContext it will recover from the panic and +// attempt to figure out the route from the URL's path. +func traceChiRoutesSafely(r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + logrus.WithField("error", rec).Error("unable to trace chi routes, traces may be off") + + span := trace.SpanFromContext(r.Context()) + span.SetAttributes(semconv.HTTPRouteKey.String(r.URL.Path)) + } + }() + + routeContext := chi.RouteContext(r.Context()) + span := trace.SpanFromContext(r.Context()) + span.SetAttributes(semconv.HTTPRouteKey.String(routeContext.RoutePattern())) +} + +// traceChiRouteURLParamsSafely attempts to extract the Chi RouteContext +// URLParams values for the route and assign them to the tracing span. If the +// request does not have a RouteContext it will recover from the panic and not +// set any params. +func traceChiRouteURLParamsSafely(r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + logrus.WithField("error", rec).Error("unable to trace route with route params, traces may be off") + } + }() + + routeContext := chi.RouteContext(r.Context()) + span := trace.SpanFromContext(r.Context()) + + var attributes []attribute.KeyValue + + for i := 0; i < len(routeContext.URLParams.Keys); i += 1 { + key := routeContext.URLParams.Keys[i] + value := routeContext.URLParams.Values[i] + + attributes = append(attributes, attribute.String("http.route.param."+key, value)) + } + + if len(attributes) > 0 { + span.SetAttributes(attributes...) + } +} + +type interceptingResponseWriter struct { + writer http.ResponseWriter + + statusCode int +} + +func (w *interceptingResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + + w.writer.WriteHeader(statusCode) +} + +func (w *interceptingResponseWriter) Write(data []byte) (int, error) { + return w.writer.Write(data) +} + +func (w *interceptingResponseWriter) Header() http.Header { + return w.writer.Header() +} + +// countStatusCodesSafely counts the number of HTTP status codes per route that +// occurred while GoTrue was running. If it is not able to identify the route +// via chi.RouteContext(ctx).RoutePattern() it counts with a noroute attribute. +func countStatusCodesSafely(w *interceptingResponseWriter, r *http.Request, counter metric.Int64Counter) { + if counter == nil { + return + } + + defer func() { + if rec := recover(); rec != nil { + logrus.WithField("error", rec).Error("unable to count status codes safely, metrics may be off") + counter.Add( + r.Context(), + 1, + metric.WithAttributes( + attribute.Bool("noroute", true), + attribute.Int("code", w.statusCode)), + ) + } + }() + + ctx := r.Context() + + routeContext := chi.RouteContext(ctx) + routePattern := semconv.HTTPRouteKey.String(routeContext.RoutePattern()) + + counter.Add( + ctx, + 1, + metric.WithAttributes(attribute.Int("code", w.statusCode), routePattern), + ) +} + +// RequestTracing returns an HTTP handler that traces all HTTP requests coming +// in. Supports Chi routers, so this should be one of the first middlewares on +// the router. +func RequestTracing() func(http.Handler) http.Handler { + meter := otel.Meter("gotrue") + statusCodes, err := meter.Int64Counter( + "http_status_codes", + metric.WithDescription("Number of returned HTTP status codes"), + ) + if err != nil { + logrus.WithError(err).Error("unable to get gotrue.http_status_codes counter metric") + } + + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + writer := interceptingResponseWriter{ + writer: w, + } + + defer traceChiRoutesSafely(r) + defer traceChiRouteURLParamsSafely(r) + defer countStatusCodesSafely(&writer, r, statusCodes) + + originalUserAgent := r.Header.Get("X-Gotrue-Original-User-Agent") + if originalUserAgent != "" { + r.Header.Set("User-Agent", originalUserAgent) + } + + next.ServeHTTP(&writer, r) + + if originalUserAgent != "" { + r.Header.Set("X-Gotrue-Original-User-Agent", originalUserAgent) + r.Header.Set("User-Agent", "stripped") + } + } + + otelHandler := otelhttp.NewHandler(http.HandlerFunc(fn), "api") + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // there is a vulnerability with otelhttp where + // User-Agent strings are kept in RAM indefinitely and + // can be used as an easy way to resource exhaustion; + // so this code strips the User-Agent header before + // it's passed to be traced by otelhttp, and then is + // returned back to the middleware + // https://github.com/supabase/gotrue/security/dependabot/11 + userAgent := r.UserAgent() + if userAgent != "" { + r.Header.Set("X-Gotrue-Original-User-Agent", userAgent) + r.Header.Set("User-Agent", "stripped") + } + + otelHandler.ServeHTTP(w, r) + }) + } +} diff --git a/internal/observability/tracing.go b/internal/observability/tracing.go new file mode 100644 index 000000000..cc1847131 --- /dev/null +++ b/internal/observability/tracing.go @@ -0,0 +1,130 @@ +package observability + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/propagation" + sdkresource "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +func Tracer(name string, opts ...trace.TracerOption) trace.Tracer { + return otel.Tracer(name, opts...) +} + +func openTelemetryResource() *sdkresource.Resource { + environmentResource := sdkresource.Environment() + gotrueResource := sdkresource.NewSchemaless(attribute.String("gotrue.version", utilities.Version)) + + mergedResource, err := sdkresource.Merge(environmentResource, gotrueResource) + if err != nil { + logrus.WithError(err).Error("unable to merge OpenTelemetry environment and gotrue resources") + + return environmentResource + } + + return mergedResource +} + +func enableOpenTelemetryTracing(ctx context.Context, tc *conf.TracingConfig) error { + var ( + err error + traceExporter *otlptrace.Exporter + ) + + switch tc.ExporterProtocol { + case "grpc": + traceExporter, err = otlptracegrpc.New(ctx) + if err != nil { + return err + } + + case "http/protobuf": + traceExporter, err = otlptracehttp.New(ctx) + if err != nil { + return err + } + + default: // http/json for example + return fmt.Errorf("unsupported OpenTelemetry exporter protocol %q", tc.ExporterProtocol) + } + + traceProvider := sdktrace.NewTracerProvider( + sdktrace.WithBatcher(traceExporter), + sdktrace.WithResource(openTelemetryResource()), + ) + + otel.SetTracerProvider(traceProvider) + + // Register the W3C trace context and baggage propagators so data is + // propagated across services/processes + otel.SetTextMapPropagator( + propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + ), + ) + + cleanupWaitGroup.Add(1) + go func() { + defer cleanupWaitGroup.Done() + + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := traceExporter.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Error("unable to shutdown OpenTelemetry trace exporter") + } + + if err := traceProvider.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Error("unable to shutdown OpenTelemetry trace provider") + } + }() + + logrus.Info("OpenTelemetry trace exporter started") + + return nil +} + +var ( + tracingOnce sync.Once +) + +// ConfigureTracing sets up global tracing configuration for OpenTracing / +// OpenTelemetry. The context should be the global context. Cancelling this +// context will cancel tracing collection. +func ConfigureTracing(ctx context.Context, tc *conf.TracingConfig) error { + if ctx == nil { + panic("context must not be nil") + } + + var err error + + tracingOnce.Do(func() { + if tc.Enabled { + if tc.Exporter == conf.OpenTelemetryTracing { + if err = enableOpenTelemetryTracing(ctx, tc); err != nil { + logrus.WithError(err).Error("unable to start OTLP trace exporter") + } + + } + } + }) + + return err +} diff --git a/internal/ratelimit/burst.go b/internal/ratelimit/burst.go new file mode 100644 index 000000000..6ae0ef58b --- /dev/null +++ b/internal/ratelimit/burst.go @@ -0,0 +1,60 @@ +package ratelimit + +import ( + "time" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/time/rate" +) + +const defaultOverTime = time.Hour + +// BurstLimiter wraps the golang.org/x/time/rate package. +type BurstLimiter struct { + rl *rate.Limiter +} + +// NewBurstLimiter returns a rate limiter configured using the given conf.Rate. +// +// The returned Limiter will be configured with a token bucket containing a +// single token, which will fill up at a rate of 1 event per r.OverTime with +// an initial burst amount of r.Events. +// +// For example: +// - 1/10s is 1 events per 10 seconds with burst of 1. +// - 1/2s is 1 events per 2 seconds with burst of 1. +// - 10/10s is 1 events per 10 seconds with burst of 10. +// +// If Rate.Events is <= 0, the burst amount will be set to 1. +// +// See Example_newBurstLimiter for a visualization. +func NewBurstLimiter(r conf.Rate) *BurstLimiter { + // The rate limiter deals in events per second. + d := r.OverTime + if d <= 0 { + d = defaultOverTime + } + + e := r.Events + if e <= 0 { + e = 0 + } + + // BurstLimiter will have an initial token bucket of size `e`. It will + // be refilled at a rate of 1 per duration `d` indefinitely. + rl := &BurstLimiter{ + rl: rate.NewLimiter(rate.Every(d), int(e)), + } + return rl +} + +// Allow implements Limiter by calling AllowAt with the current time. +func (l *BurstLimiter) Allow() bool { + return l.AllowAt(time.Now()) +} + +// AllowAt implements Limiter by calling the underlying x/time/rate.Limiter +// with the given time. +func (l *BurstLimiter) AllowAt(at time.Time) bool { + return l.rl.AllowN(at, 1) +} diff --git a/internal/ratelimit/burst_test.go b/internal/ratelimit/burst_test.go new file mode 100644 index 000000000..b854e3b27 --- /dev/null +++ b/internal/ratelimit/burst_test.go @@ -0,0 +1,214 @@ +package ratelimit + +import ( + "fmt" + "testing" + "time" + + "github.com/supabase/auth/internal/conf" +) + +func Example_newBurstLimiter() { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + { + cfg := conf.Rate{Events: 10, OverTime: time.Second * 20} + rl := NewBurstLimiter(cfg) + cur := now + for i := 0; i < 20; i++ { + allowed := rl.AllowAt(cur) + fmt.Printf("%-5v @ %v\n", allowed, cur) + cur = cur.Add(time.Second * 5) + } + } + + // Output: + // true @ 2024-09-24 10:00:00 +0000 UTC + // true @ 2024-09-24 10:00:05 +0000 UTC + // true @ 2024-09-24 10:00:10 +0000 UTC + // true @ 2024-09-24 10:00:15 +0000 UTC + // true @ 2024-09-24 10:00:20 +0000 UTC + // true @ 2024-09-24 10:00:25 +0000 UTC + // true @ 2024-09-24 10:00:30 +0000 UTC + // true @ 2024-09-24 10:00:35 +0000 UTC + // true @ 2024-09-24 10:00:40 +0000 UTC + // true @ 2024-09-24 10:00:45 +0000 UTC + // true @ 2024-09-24 10:00:50 +0000 UTC + // true @ 2024-09-24 10:00:55 +0000 UTC + // true @ 2024-09-24 10:01:00 +0000 UTC + // false @ 2024-09-24 10:01:05 +0000 UTC + // false @ 2024-09-24 10:01:10 +0000 UTC + // false @ 2024-09-24 10:01:15 +0000 UTC + // true @ 2024-09-24 10:01:20 +0000 UTC + // false @ 2024-09-24 10:01:25 +0000 UTC + // false @ 2024-09-24 10:01:30 +0000 UTC + // false @ 2024-09-24 10:01:35 +0000 UTC +} + +func TestBurstLimiter(t *testing.T) { + t.Run("Allow", func(t *testing.T) { + for i := 1; i < 10; i++ { + cfg := conf.Rate{Events: float64(i), OverTime: time.Hour} + rl := NewBurstLimiter(cfg) + for y := i; y > 0; y-- { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + if exp, got := false, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + }) + + t.Run("AllowAt", func(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + + type event struct { + ok bool + at time.Time + + // Event should be `ok` at `at` for `i` times + i int + } + + type testCase struct { + cfg conf.Rate + now time.Time + evts []event + } + cases := []testCase{ + { + cfg: conf.Rate{Events: 20, OverTime: time.Second * 20}, + now: now, + evts: []event{ + // initial burst of 20 is permitted + {true, now, 19}, + + // then denied, even at same time + {false, now, 100}, + + // and continue to deny until the next generated token + {false, now.Add(time.Second), 100}, + {false, now.Add(time.Second * 19), 100}, + + // allows a single call to allow at 20 seconds + {true, now.Add(time.Second * 20), 0}, + + // then denied + {false, now.Add(time.Second * 20), 100}, + + // and the pattern repeats + {true, now.Add(time.Second * 40), 0}, + {false, now.Add(time.Second * 40), 100}, + {false, now.Add(time.Second * 59), 100}, + + {true, now.Add(time.Second * 60), 0}, + {false, now.Add(time.Second * 60), 100}, + {false, now.Add(time.Second * 79), 100}, + + {true, now.Add(time.Second * 80), 0}, + {false, now.Add(time.Second * 80), 100}, + {false, now.Add(time.Second * 99), 100}, + + // allow tokens to be built up still + {true, now.Add(time.Hour), 19}, + }, + }, + + { + cfg: conf.Rate{Events: 1, OverTime: time.Second * 20}, + now: now, + evts: []event{ + // initial burst of 1 is permitted + {true, now, 0}, + + // then denied, even at same time + {false, now, 100}, + + // and continue to deny until the next generated token + {false, now.Add(time.Second), 100}, + {false, now.Add(time.Second * 19), 100}, + + // allows a single call to allow at 20 seconds + {true, now.Add(time.Second * 20), 0}, + + // then denied + {false, now.Add(time.Second * 20), 100}, + + // and the pattern repeats + {true, now.Add(time.Second * 40), 0}, + {false, now.Add(time.Second * 40), 100}, + {false, now.Add(time.Second * 59), 100}, + + {true, now.Add(time.Second * 60), 0}, + {false, now.Add(time.Second * 60), 100}, + {false, now.Add(time.Second * 79), 100}, + + {true, now.Add(time.Second * 80), 0}, + {false, now.Add(time.Second * 80), 100}, + {false, now.Add(time.Second * 99), 100}, + }, + }, + + // 1 event per second + { + cfg: conf.Rate{Events: 1, OverTime: time.Second}, + now: now, + evts: []event{ + {true, now, 0}, + {true, now.Add(time.Second), 0}, + {false, now.Add(time.Second), 0}, + {true, now.Add(time.Second * 2), 0}, + }, + }, + + // 1 events per second and OverTime = 1 event per hour. + { + cfg: conf.Rate{Events: 1, OverTime: 0}, + now: now, + evts: []event{ + {true, now, 0}, + {false, now.Add(time.Hour - time.Second), 0}, + {true, now.Add(time.Hour), 0}, + {true, now.Add(time.Hour * 2), 0}, + }, + }, + + // zero value for Events = 0 event per second + { + cfg: conf.Rate{Events: 0, OverTime: time.Second}, + now: now, + evts: []event{ + {false, now, 0}, + {false, now.Add(-time.Second), 0}, + {false, now.Add(time.Second), 0}, + {false, now.Add(time.Second * 2), 0}, + }, + }, + + // zero value for both Events and OverTime = 1 event per hour. + { + cfg: conf.Rate{Events: 0, OverTime: 0}, + now: now, + evts: []event{ + {false, now, 0}, + {false, now.Add(time.Hour - time.Second), 0}, + {false, now.Add(-time.Hour), 0}, + {false, now.Add(time.Hour), 0}, + {false, now.Add(time.Hour * 2), 0}, + }, + }, + } + + for _, tc := range cases { + rl := NewBurstLimiter(tc.cfg) + for _, evt := range tc.evts { + for i := 0; i <= evt.i; i++ { + if exp, got := evt.ok, rl.AllowAt(evt.at); exp != got { + t.Fatalf("exp AllowAt(%v) to be %v; got %v", evt.at, exp, got) + } + } + } + } + }) +} diff --git a/internal/ratelimit/interval.go b/internal/ratelimit/interval.go new file mode 100644 index 000000000..a72302f74 --- /dev/null +++ b/internal/ratelimit/interval.go @@ -0,0 +1,63 @@ +package ratelimit + +import ( + "sync" + "time" + + "github.com/supabase/auth/internal/conf" +) + +// IntervalLimiter will limit the number of calls to Allow per interval. +type IntervalLimiter struct { + mu sync.Mutex + ival time.Duration // Count is reset and time updated every ival. + limit int // Limit calls to Allow() per ival. + + // Guarded by mu. + last time.Time // When the limiter was last reset. + count int // Total calls to Allow() since time. +} + +// NewIntervalLimiter returns a rate limiter using the given conf.Rate. +func NewIntervalLimiter(r conf.Rate) *IntervalLimiter { + return &IntervalLimiter{ + ival: r.OverTime, + limit: int(r.Events), + last: time.Now(), + } +} + +// Allow implements Limiter by calling AllowAt with the current time. +func (rl *IntervalLimiter) Allow() bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + return rl.allowAt(time.Now()) +} + +// AllowAt implements Limiter by checking if the current number of permitted +// events within this interval would permit 1 additional event at the current +// time. +// +// When called with a time outside the current active interval the counter is +// reset, meaning it can be vulnerable at the edge of it's intervals so avoid +// small intervals. +func (rl *IntervalLimiter) AllowAt(at time.Time) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + return rl.allowAt(at) +} + +func (rl *IntervalLimiter) allowAt(at time.Time) bool { + since := at.Sub(rl.last) + if ivals := int64(since / rl.ival); ivals > 0 { + rl.last = rl.last.Add(time.Duration(ivals) * rl.ival) + rl.count = 0 + } + if rl.count < rl.limit { + rl.count++ + return true + } + return false +} diff --git a/internal/ratelimit/interval_test.go b/internal/ratelimit/interval_test.go new file mode 100644 index 000000000..835ee8256 --- /dev/null +++ b/internal/ratelimit/interval_test.go @@ -0,0 +1,81 @@ +package ratelimit + +import ( + "fmt" + "testing" + "time" + + "github.com/supabase/auth/internal/conf" +) + +func Example_newIntervalLimiter() { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24} + rl := NewIntervalLimiter(cfg) + rl.last = now + + cur := now + allowed := 0 + + for days := 0; days < 2; days++ { + // First 100 events succeed. + for i := 0; i < 100; i++ { + allow := rl.allowAt(cur) + cur = cur.Add(time.Second) + + if !allow { + fmt.Printf("false @ %v after %v events... [FAILED]\n", cur, allowed) + return + } + allowed++ + } + fmt.Printf("true @ %v for last %v events...\n", cur, allowed) + + // We try hourly until it allows us to make requests again. + denied := 0 + for i := 0; i < 23; i++ { + cur = cur.Add(time.Hour) + allow := rl.AllowAt(cur) + if allow { + fmt.Printf("true @ %v before quota reset... [FAILED]\n", cur) + return + } + denied++ + } + fmt.Printf("false @ %v for last %v events...\n", cur, denied) + + cur = cur.Add(time.Hour) + } + + // Output: + // true @ 2024-09-24 10:01:40 +0000 UTC for last 100 events... + // false @ 2024-09-25 09:01:40 +0000 UTC for last 23 events... + // true @ 2024-09-25 10:03:20 +0000 UTC for last 200 events... + // false @ 2024-09-26 09:03:20 +0000 UTC for last 23 events... +} + +func TestNewIntervalLimiter(t *testing.T) { + t.Run("Allow", func(t *testing.T) { + for i := 1; i < 10; i++ { + cfg := conf.Rate{Events: float64(i), OverTime: time.Hour} + rl := NewIntervalLimiter(cfg) + for y := i; y > 0; y-- { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + if exp, got := false, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + + // should accept a negative burst. + cfg := conf.Rate{Events: 10, OverTime: time.Hour} + rl := NewBurstLimiter(cfg) + for y := 0; y < 10; y++ { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + }) +} diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 000000000..35fbf9bc4 --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,34 @@ +package ratelimit + +import ( + "time" + + "github.com/supabase/auth/internal/conf" +) + +// Limiter is the interface implemented by rate limiters. +// +// Implementations of Limiter must be safe for concurrent use. +type Limiter interface { + + // Allow should return true if an event should be allowed at the time + // which it was called, or false otherwise. + Allow() bool + + // AllowAt should return true if an event should be allowed at the given + // time, or false otherwise. + AllowAt(at time.Time) bool +} + +// New returns a new Limiter based on the given config. +// +// When the type is conf.BurstRateType it returns a BurstLimiter, otherwise +// New returns an IntervalLimiter. +func New(r conf.Rate) Limiter { + switch r.GetRateType() { + case conf.BurstRateType: + return NewBurstLimiter(r) + default: + return NewIntervalLimiter(r) + } +} diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go new file mode 100644 index 000000000..3bac1dca2 --- /dev/null +++ b/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,50 @@ +package ratelimit + +import ( + "testing" + + "github.com/supabase/auth/internal/conf" +) + +func TestNew(t *testing.T) { + + // IntervalLimiter + { + var r conf.Rate + err := r.Decode("100") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*IntervalLimiter); !ok { + t.Fatalf("exp type *IntervalLimiter; got %T", rl) + } + } + { + var r conf.Rate + err := r.Decode("100.123") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*IntervalLimiter); !ok { + t.Fatalf("exp type *IntervalLimiter; got %T", rl) + } + } + + // BurstLimiter + { + var r conf.Rate + err := r.Decode("20/200s") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*BurstLimiter); !ok { + t.Fatalf("exp type *BurstLimiter; got %T", rl) + } + } +} diff --git a/internal/reloader/handler.go b/internal/reloader/handler.go new file mode 100644 index 000000000..bdd15ca88 --- /dev/null +++ b/internal/reloader/handler.go @@ -0,0 +1,42 @@ +package reloader + +import ( + "net/http" + "sync/atomic" +) + +// AtomicHandler provides an atomic http.Handler implementation, allowing safe +// handler replacement at runtime. AtomicHandler must be initialized with a call +// to NewAtomicHandler. It will never panic and is safe for concurrent use. +type AtomicHandler struct { + val atomic.Value +} + +// atomicHandlerValue is the value stored within an atomicHandler. +type atomicHandlerValue struct{ http.Handler } + +// NewAtomicHandler creates a new AtomicHandler ready for use. +func NewAtomicHandler(h http.Handler) *AtomicHandler { + ah := new(AtomicHandler) + ah.Store(h) + return ah +} + +// String implements fmt.Stringer by returning a string literal. +func (ah *AtomicHandler) String() string { return "reloader.AtomicHandler" } + +// Store will update this http.Handler to serve future requests using h. +func (ah *AtomicHandler) Store(h http.Handler) { + ah.val.Store(&atomicHandlerValue{h}) +} + +// load will return the underlying http.Handler used to serve requests. +func (ah *AtomicHandler) load() http.Handler { + return ah.val.Load().(*atomicHandlerValue).Handler +} + +// ServeHTTP implements the standard libraries http.Handler interface by +// atomically passing the request along to the most recently stored handler. +func (ah *AtomicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ah.load().ServeHTTP(w, r) +} diff --git a/internal/reloader/handler_race_test.go b/internal/reloader/handler_race_test.go new file mode 100644 index 000000000..fd0236e1a --- /dev/null +++ b/internal/reloader/handler_race_test.go @@ -0,0 +1,67 @@ +//go:build race +// +build race + +package reloader + +import ( + "context" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicHandlerRaces(t *testing.T) { + type testHandler struct{ http.Handler } + + hrFn := func() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + } + + const count = 8 + hrFuncMap := make(map[http.Handler]struct{}, count) + for i := 0; i < count; i++ { + hrFuncMap[&testHandler{hrFn()}] = struct{}{} + } + + hr := NewAtomicHandler(nil) + assert.NotNil(t, hr) + + var wg sync.WaitGroup + defer wg.Wait() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second/4) + defer cancel() + + // We create 8 goroutines reading & writing to the handler concurrently. If + // a race condition occurs the test will fail and halt. + for i := 0; i < count; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + for hrFunc := range hrFuncMap { + select { + case <-ctx.Done(): + default: + } + + hr.Store(hrFunc) + + // Calling string should be safe + hr.String() + + got := hr.load() + _, ok := hrFuncMap[got] + if !ok { + // This will trigger a race failure / exit test + t.Fatal("unknown handler returned from load()") + return + } + } + }() + } + wg.Wait() +} diff --git a/internal/reloader/handler_test.go b/internal/reloader/handler_test.go new file mode 100644 index 000000000..b3acc21da --- /dev/null +++ b/internal/reloader/handler_test.go @@ -0,0 +1,55 @@ +package reloader + +import ( + "net/http" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicHandler(t *testing.T) { + // for ptr identity + type testHandler struct{ http.Handler } + + var calls atomic.Int64 + hrFn := func() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls.Add(1) + }) + } + + hrFunc1 := &testHandler{hrFn()} + hrFunc2 := &testHandler{hrFn()} + assert.NotEqual(t, hrFunc1, hrFunc2) + + // a new AtomicHandler should be non-nil + hr := NewAtomicHandler(hrFunc1) + assert.NotNil(t, hr) + assert.Equal(t, "reloader.AtomicHandler", hr.String()) + + // should implement http.Handler + { + v := (http.Handler)(hr) + before := calls.Load() + v.ServeHTTP(nil, nil) + after := calls.Load() + if exp, got := before+1, after; exp != got { + t.Fatalf("exp %v to be %v after handler was called", got, exp) + } + } + + // should be non-nil after store + for i := 0; i < 3; i++ { + hr.Store(hrFunc1) + assert.NotNil(t, hr.load()) + assert.Equal(t, hr.load(), hrFunc1) + assert.Equal(t, hr.load() == hrFunc1, true) + + // should update to hrFunc2 + hr.Store(hrFunc2) + assert.NotNil(t, hr.load()) + assert.Equal(t, hr.load(), hrFunc2) + assert.Equal(t, hr.load() == hrFunc2, true) + } +} diff --git a/internal/reloader/reloader.go b/internal/reloader/reloader.go new file mode 100644 index 000000000..c518ae374 --- /dev/null +++ b/internal/reloader/reloader.go @@ -0,0 +1,242 @@ +// Package reloader provides support for live configuration reloading. +package reloader + +import ( + "context" + "errors" + "strings" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" +) + +const ( + // reloadInterval is the interval between configuration reloading. At most + // one configuration change may be made between this duration. + reloadInterval = time.Second * 10 + + // tickerInterval is the maximum latency between configuration reloads. + tickerInterval = reloadInterval / 10 +) + +type ConfigFunc func(*conf.GlobalConfiguration) + +type Reloader struct { + watchDir string + reloadIval time.Duration + tickerIval time.Duration + watchFn func() (watcher, error) + reloadFn func(dir string) (*conf.GlobalConfiguration, error) + addDirFn func(ctx context.Context, wr watcher, dir string, dur time.Duration) error +} + +func NewReloader(watchDir string) *Reloader { + return &Reloader{ + watchDir: watchDir, + reloadIval: reloadInterval, + tickerIval: tickerInterval, + watchFn: newFSWatcher, + reloadFn: defaultReloadFn, + addDirFn: defaultAddDirFn, + } +} + +// reload attempts to create a new *conf.GlobalConfiguration after loading the +// currently configured watchDir. +func (rl *Reloader) reload() (*conf.GlobalConfiguration, error) { + return rl.reloadFn(rl.watchDir) +} + +// reloadCheckAt checks if reloadConfig should be called, returns true if config +// should be reloaded or false otherwise. +func (rl *Reloader) reloadCheckAt(at, lastUpdate time.Time) bool { + if lastUpdate.IsZero() { + return false // no pending updates + } + if at.Sub(lastUpdate) < rl.reloadIval { + return false // waiting for reload interval + } + + // Update is pending. + return true +} + +func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error { + wr, err := rl.watchFn() + if err != nil { + logrus.WithError(err).Error("reloader: error creating fsnotify Watcher") + return err + } + defer wr.Close() + + tr := time.NewTicker(rl.tickerIval) + defer tr.Stop() + + // Ignore errors, if watch dir doesn't exist we can add it later. + if err := rl.addDirFn(ctx, wr, rl.watchDir, reloadInterval); err != nil { + logrus.WithError(err).Error("reloader: error watching config directory") + } + + var lastUpdate time.Time + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case <-tr.C: + // This is a simple way to solve watch dir being added later or + // being moved and then recreated. I've tested all of these basic + // scenarios and wr.WatchList() does not grow which aligns with + // the documented behavior. + if err := rl.addDirFn(ctx, wr, rl.watchDir, reloadInterval); err != nil { + logrus.WithError(err).Error("reloader: error watching config directory") + } + + // Check to see if the config is ready to be relaoded. + if !rl.reloadCheckAt(time.Now(), lastUpdate) { + continue + } + + // Reset the last update time before we try to reload the config. + lastUpdate = time.Time{} + + cfg, err := rl.reload() + if err != nil { + logrus.WithError(err).Error("reloader: error loading config") + continue + } + + // Call the callback function with the latest cfg. + fn(cfg) + + case evt, ok := <-wr.Events(): + if !ok { + err := errors.New("reloader: fsnotify event channel was closed") + logrus.WithError(err).Error(err) + return err + } + + // We only read files ending in .env + if !strings.HasSuffix(evt.Name, ".env") { + continue + } + + switch { + case evt.Op.Has(fsnotify.Create), + evt.Op.Has(fsnotify.Remove), + evt.Op.Has(fsnotify.Rename), + evt.Op.Has(fsnotify.Write): + lastUpdate = time.Now() + } + case err, ok := <-wr.Errors(): + if !ok { + err := errors.New("reloader: fsnotify error channel was closed") + logrus.WithError(err).Error(err) + return err + } + logrus.WithError(err).Error( + "reloader: fsnotify has reported an error") + } + } +} + +// defaultAddDirFn adds a dir to a watcher with a common error and sleep +// duration if the directory doesn't exist. +func defaultAddDirFn(ctx context.Context, wr watcher, dir string, dur time.Duration) error { + if err := wr.Add(dir); err != nil { + tr := time.NewTicker(dur) + defer tr.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-tr.C: + return err + } + } + return nil +} + +func defaultReloadFn(dir string) (*conf.GlobalConfiguration, error) { + if err := conf.LoadDirectory(dir); err != nil { + return nil, err + } + + cfg, err := conf.LoadGlobalFromEnv() + if err != nil { + return nil, err + } + return cfg, nil +} + +type watcher interface { + Add(path string) error + Close() error + Events() chan fsnotify.Event + Errors() chan error +} + +type fsNotifyWatcher struct { + wr *fsnotify.Watcher +} + +func newFSWatcher() (watcher, error) { + wr, err := fsnotify.NewWatcher() + return &fsNotifyWatcher{wr}, err +} + +func (o *fsNotifyWatcher) Add(path string) error { return o.wr.Add(path) } +func (o *fsNotifyWatcher) Close() error { return o.wr.Close() } +func (o *fsNotifyWatcher) Errors() chan error { return o.wr.Errors } +func (o *fsNotifyWatcher) Events() chan fsnotify.Event { return o.wr.Events } + +type mockWatcher struct { + mu sync.Mutex + err error + eventCh chan fsnotify.Event + errorCh chan error + addCh chan string +} + +func newMockWatcher(err error) *mockWatcher { + wr := &mockWatcher{ + err: err, + eventCh: make(chan fsnotify.Event, 1024), + errorCh: make(chan error, 1024), + addCh: make(chan string, 1024), + } + return wr +} + +func (o *mockWatcher) getErr() error { + o.mu.Lock() + defer o.mu.Unlock() + err := o.err + return err +} + +func (o *mockWatcher) setErr(err error) { + o.mu.Lock() + defer o.mu.Unlock() + o.err = err +} + +func (o *mockWatcher) Add(path string) error { + o.mu.Lock() + defer o.mu.Unlock() + if err := o.err; err != nil { + return err + } + + select { + case o.addCh <- path: + default: + } + return nil +} +func (o *mockWatcher) Close() error { return o.getErr() } +func (o *mockWatcher) Events() chan fsnotify.Event { return o.eventCh } +func (o *mockWatcher) Errors() chan error { return o.errorCh } diff --git a/internal/reloader/reloader_test.go b/internal/reloader/reloader_test.go new file mode 100644 index 000000000..aeaea5a22 --- /dev/null +++ b/internal/reloader/reloader_test.go @@ -0,0 +1,515 @@ +package reloader + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "path" + "path/filepath" + "testing" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/supabase/auth/internal/conf" + "golang.org/x/sync/errgroup" +) + +func TestWatch(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + dir, cleanup := helpTestDir(t) + defer cleanup() + + // test broken watcher + { + sentinelErr := errors.New("sentinel") + rr := mockReloadRecorder() + rl := NewReloader(dir) + rl.watchFn = func() (watcher, error) { return nil, sentinelErr } + + err := rl.Watch(ctx, rr.configFn) + if exp, got := sentinelErr, err; exp != got { + assert.Equal(t, exp, got) + } + } + + // test watch invalid dir + { + doneCtx, doneCancel := context.WithCancel(ctx) + doneCancel() + + rr := mockReloadRecorder() + rl := NewReloader(path.Join(dir, "__not_found__")) + err := rl.Watch(doneCtx, rr.configFn) + if exp, got := context.Canceled, err; exp != got { + assert.Equal(t, exp, got) + } + } + + // test watch invalid dir in addDirFn + { + + sentinel := errors.New("sentinel") + wr := newMockWatcher(sentinel) + rl := NewReloader(path.Join(dir, "__not_found__")) + rl.watchFn = func() (watcher, error) { return wr, nil } + + err := rl.addDirFn(ctx, wr, "__not_found__", time.Millisecond) + if exp, got := sentinel, err; exp != got { + assert.Equal(t, exp, got) + } + } + + // test watch error chan closed + { + rr := mockReloadRecorder() + wr := newMockWatcher(nil) + wr.errorCh <- errors.New("sentinel") + close(wr.errorCh) + + rl := NewReloader(dir) + rl.watchFn = func() (watcher, error) { return wr, nil } + + err := rl.Watch(ctx, rr.configFn) + assert.NotNil(t, err) + + msg := "reloader: fsnotify error channel was closed" + if exp, got := msg, err.Error(); exp != got { + assert.Equal(t, exp, got) + } + } + + // test watch event chan closed + { + rr := mockReloadRecorder() + wr := newMockWatcher(nil) + close(wr.eventCh) + + rl := NewReloader(dir) + rl.reloadIval = time.Second / 100 + rl.watchFn = func() (watcher, error) { return wr, nil } + + err := rl.Watch(ctx, rr.configFn) + if err == nil { + assert.NotNil(t, err) + } + + msg := "reloader: fsnotify event channel was closed" + if exp, got := msg, err.Error(); exp != got { + assert.Equal(t, exp, got) + } + } + + // test watch error chan + { + rr := mockReloadRecorder() + wr := newMockWatcher(nil) + wr.errorCh <- errors.New("sentinel") + + rl := NewReloader(dir) + rl.watchFn = func() (watcher, error) { return wr, nil } + + egCtx, egCancel := context.WithCancel(ctx) + defer egCancel() + + var eg errgroup.Group + eg.Go(func() error { + return rl.Watch(egCtx, rr.configFn) + }) + + // need to ensure errorCh drains so test isn't racey + eg.Go(func() error { + defer egCancel() + + tr := time.NewTicker(time.Second / 100) + defer tr.Stop() + + for { + select { + case <-egCtx.Done(): + return egCtx.Err() + case <-tr.C: + if len(wr.errorCh) == 0 { + return nil + } + } + } + }) + + err := eg.Wait() + if exp, got := context.Canceled, err; exp != got { + assert.Equal(t, exp, got) + } + } + + // test an end to end config reload + { + rr := mockReloadRecorder() + wr := newMockWatcher(nil) + rl := NewReloader(dir) + rl.watchFn = func() (watcher, error) { return wr, wr.getErr() } + rl.reloadFn = rr.reloadFn + rl.addDirFn = func(ctx context.Context, wr watcher, dir string, dur time.Duration) error { + if err := wr.Add(dir); err != nil { + logrus.WithError(err).Error("reloader: error watching config directory") + return err + } + return nil + } + + // Need to lower reload ival to pickup config write quicker. + rl.reloadIval = time.Second / 10 + rl.tickerIval = rl.reloadIval / 10 + + egCtx, egCancel := context.WithCancel(ctx) + defer egCancel() + + var eg errgroup.Group + eg.Go(func() error { + return rl.Watch(egCtx, rr.configFn) + }) + + // Copy a full and valid example configuration to trigger Watch + { + select { + case <-egCtx.Done(): + assert.Nil(t, egCtx.Err()) + case v := <-wr.addCh: + assert.Equal(t, v, dir) + } + + name := helpCopyEnvFile(t, dir, "01_example.env", "testdata/50_example.env") + wr.eventCh <- fsnotify.Event{ + Name: name, + Op: fsnotify.Create, + } + select { + case <-egCtx.Done(): + assert.Nil(t, egCtx.Err()) + case cfg := <-rr.configCh: + assert.NotNil(t, cfg) + assert.Equal(t, cfg.External.Apple.Enabled, false) + } + } + + { + drain(rr.configCh) + drain(rr.reloadCh) + + name := helpWriteEnvFile(t, dir, "02_example.env", map[string]string{ + "GOTRUE_EXTERNAL_APPLE_ENABLED": "true", + }) + wr.eventCh <- fsnotify.Event{ + Name: name, + Op: fsnotify.Create, + } + select { + case <-egCtx.Done(): + assert.Nil(t, egCtx.Err()) + case cfg := <-rr.configCh: + assert.NotNil(t, cfg) + assert.Equal(t, cfg.External.Apple.Enabled, true) + } + } + + { + name := helpWriteEnvFile(t, dir, "03_example.env.bak", map[string]string{ + "GOTRUE_EXTERNAL_APPLE_ENABLED": "false", + }) + wr.eventCh <- fsnotify.Event{ + Name: name, + Op: fsnotify.Create, + } + } + + { + // empty the reload ch + drain(rr.reloadCh) + + name := helpWriteEnvFile(t, dir, "04_example.env", map[string]string{ + "GOTRUE_SMTP_PORT": "ABC", + }) + wr.eventCh <- fsnotify.Event{ + Name: name, + Op: fsnotify.Create, + } + + select { + case <-egCtx.Done(): + assert.Nil(t, egCtx.Err()) + case p := <-rr.reloadCh: + if exp, got := dir, p; exp != got { + assert.Equal(t, exp, got) + } + } + } + + { + name := helpWriteEnvFile(t, dir, "05_example.env", map[string]string{ + "GOTRUE_SMTP_PORT": "2222", + }) + wr.eventCh <- fsnotify.Event{ + Name: name, + Op: fsnotify.Create, + } + select { + case <-egCtx.Done(): + assert.Nil(t, egCtx.Err()) + case cfg := <-rr.configCh: + assert.NotNil(t, cfg) + assert.Equal(t, cfg.SMTP.Port, 2222) + } + } + + // test the wr.Add doesn't exit if bad watch dir is given during tick + { + // set the error on watcher + sentinelErr := errors.New("sentinel") + wr.setErr(sentinelErr) + + name := helpWriteEnvFile(t, dir, "05_example.env", map[string]string{ + "GOTRUE_SMTP_PORT": "2222", + }) + wr.eventCh <- fsnotify.Event{ + Name: name, + Op: fsnotify.Create, + } + + select { + case <-egCtx.Done(): + assert.Nil(t, egCtx.Err()) + case cfg := <-rr.configCh: + assert.NotNil(t, cfg) + assert.Equal(t, cfg.SMTP.Port, 2222) + } + } + + // test cases ran, end context to unblock Wait() + egCancel() + + err := eg.Wait() + if exp, got := context.Canceled, err; exp != got { + assert.Equal(t, exp, got) + } + } +} + +func TestReloadConfig(t *testing.T) { + dir, cleanup := helpTestDir(t) + defer cleanup() + + rl := NewReloader(dir) + + // Copy the full and valid example configuration. + helpCopyEnvFile(t, dir, "01_example.env", "testdata/50_example.env") + { + cfg, err := rl.reload() + assert.Nil(t, err) + assert.NotNil(t, cfg) + assert.Equal(t, cfg.External.Apple.Enabled, false) + } + + helpWriteEnvFile(t, dir, "02_example.env", map[string]string{ + "GOTRUE_EXTERNAL_APPLE_ENABLED": "true", + }) + { + cfg, err := rl.reload() + assert.Nil(t, err) + assert.NotNil(t, cfg) + assert.Equal(t, cfg.External.Apple.Enabled, true) + } + + helpWriteEnvFile(t, dir, "03_example.env.bak", map[string]string{ + "GOTRUE_EXTERNAL_APPLE_ENABLED": "false", + }) + { + cfg, err := rl.reload() + assert.Nil(t, err) + assert.NotNil(t, cfg) + assert.Equal(t, cfg.External.Apple.Enabled, true) + } + + // test cfg reload failure + helpWriteEnvFile(t, dir, "04_example.env", map[string]string{ + "PORT": "INVALIDPORT", + "GOTRUE_SMTP_PORT": "ABC", + }) + { + cfg, err := rl.reload() + assert.NotNil(t, err) + assert.Nil(t, cfg) + } + + // test directory loading failure + { + cleanup() + + cfg, err := rl.reload() + assert.NotNil(t, err) + assert.Nil(t, cfg) + } +} + +func TestReloadCheckAt(t *testing.T) { + const s10 = time.Second * 10 + + now := time.Now() + tests := []struct { + rl *Reloader + at, lastUpdate time.Time + exp bool + }{ + // no lastUpdate is set (time.IsZero()) + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + exp: false, + }, + + // last update within reload interval + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(-s10 + 1), + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now, + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(s10 - 1), + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(s10), + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(s10 + 1), + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(s10 * 2), + exp: false, + }, + + // last update was outside our reload interval + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(-s10), + exp: true, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(-s10 - 1), + exp: true, + }, + } + for _, tc := range tests { + rl := tc.rl + assert.NotNil(t, rl) + assert.Equal(t, rl.reloadCheckAt(tc.at, tc.lastUpdate), tc.exp) + } +} + +func helpTestDir(t testing.TB) (dir string, cleanup func()) { + name := fmt.Sprintf("%v_%v", t.Name(), time.Now().Nanosecond()) + dir = filepath.Join("testdata", name) + err := os.MkdirAll(dir, 0750) + if err != nil && !os.IsExist(err) { + assert.Nil(t, err) + } + return dir, func() { os.RemoveAll(dir) } +} + +func helpCopyEnvFile(t testing.TB, dir, name, src string) string { + data, err := os.ReadFile(src) // #nosec G304 + if err != nil { + assert.Nil(t, err) + } + + dst := filepath.Join(dir, name) + err = os.WriteFile(dst, data, 0600) + if err != nil { + assert.Nil(t, err) + } + return dst +} + +func helpWriteEnvFile(t testing.TB, dir, name string, values map[string]string) string { + var buf bytes.Buffer + for k, v := range values { + buf.WriteString(k) + buf.WriteString("=") + buf.WriteString(v) + buf.WriteString("\n") + } + + dst := filepath.Join(dir, name) + err := os.WriteFile(dst, buf.Bytes(), 0600) + assert.Nil(t, err) + return dst +} + +func mockReloadRecorder() *reloadRecorder { + rr := &reloadRecorder{ + configCh: make(chan *conf.GlobalConfiguration, 1024), + reloadCh: make(chan string, 1024), + } + return rr +} + +func drain[C ~chan T, T any](ch C) (out []T) { + for { + select { + case v := <-ch: + out = append(out, v) + default: + return out + } + } +} + +type reloadRecorder struct { + configCh chan *conf.GlobalConfiguration + reloadCh chan string +} + +func (o *reloadRecorder) reloadFn(dir string) (*conf.GlobalConfiguration, error) { + defer func() { + select { + case o.reloadCh <- dir: + default: + } + }() + return defaultReloadFn(dir) +} + +func (o *reloadRecorder) configFn(gc *conf.GlobalConfiguration) { + select { + case o.configCh <- gc: + default: + } +} diff --git a/internal/reloader/testdata/50_example.env b/internal/reloader/testdata/50_example.env new file mode 100644 index 000000000..8ec6039f6 --- /dev/null +++ b/internal/reloader/testdata/50_example.env @@ -0,0 +1,239 @@ +# General Config +# NOTE: The service_role key is required as an authorization header for /admin endpoints + +GOTRUE_JWT_SECRET="CHANGE-THIS! VERY IMPORTANT!" +GOTRUE_JWT_EXP="3600" +GOTRUE_JWT_AUD="authenticated" +GOTRUE_JWT_DEFAULT_GROUP_NAME="authenticated" +GOTRUE_JWT_ADMIN_ROLES="supabase_admin,service_role" + +# Database & API connection details +GOTRUE_DB_DRIVER="postgres" +DB_NAMESPACE="auth" +DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/postgres" +API_EXTERNAL_URL="http://localhost:9999" +GOTRUE_API_HOST="localhost" +PORT="9999" + +# SMTP config (generate credentials for signup to work) +GOTRUE_SMTP_HOST="" +GOTRUE_SMTP_PORT="587" +GOTRUE_SMTP_USER="" +GOTRUE_SMTP_MAX_FREQUENCY="5s" +GOTRUE_SMTP_PASS="" +GOTRUE_SMTP_ADMIN_EMAIL="" +GOTRUE_SMTP_SENDER_NAME="" + +# Mailer config +GOTRUE_MAILER_AUTOCONFIRM="true" +GOTRUE_MAILER_URLPATHS_CONFIRMATION="/verify" +GOTRUE_MAILER_URLPATHS_INVITE="/verify" +GOTRUE_MAILER_URLPATHS_RECOVERY="/verify" +GOTRUE_MAILER_URLPATHS_EMAIL_CHANGE="/verify" +GOTRUE_MAILER_SUBJECTS_CONFIRMATION="Confirm Your Email" +GOTRUE_MAILER_SUBJECTS_RECOVERY="Reset Your Password" +GOTRUE_MAILER_SUBJECTS_MAGIC_LINK="Your Magic Link" +GOTRUE_MAILER_SUBJECTS_EMAIL_CHANGE="Confirm Email Change" +GOTRUE_MAILER_SUBJECTS_INVITE="You have been invited" +GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED="true" + +# Custom mailer template config +GOTRUE_MAILER_TEMPLATES_INVITE="" +GOTRUE_MAILER_TEMPLATES_CONFIRMATION="" +GOTRUE_MAILER_TEMPLATES_RECOVERY="" +GOTRUE_MAILER_TEMPLATES_MAGIC_LINK="" +GOTRUE_MAILER_TEMPLATES_EMAIL_CHANGE="" + +# Signup config +GOTRUE_DISABLE_SIGNUP="false" +GOTRUE_SITE_URL="http://localhost:3000" +GOTRUE_EXTERNAL_EMAIL_ENABLED="true" +GOTRUE_EXTERNAL_PHONE_ENABLED="true" +GOTRUE_EXTERNAL_IOS_BUNDLE_ID="com.supabase.auth" + +# Whitelist redirect to URLs here, a comma separated list of URIs (e.g. "https://foo.example.com,https://*.foo.example.com,https://bar.example.com") +GOTRUE_URI_ALLOW_LIST="http://localhost:3000" + +# Apple OAuth config +GOTRUE_EXTERNAL_APPLE_ENABLED="false" +GOTRUE_EXTERNAL_APPLE_CLIENT_ID="" +GOTRUE_EXTERNAL_APPLE_SECRET="" +GOTRUE_EXTERNAL_APPLE_REDIRECT_URI="http://localhost:9999/callback" + +# Azure OAuth config +GOTRUE_EXTERNAL_AZURE_ENABLED="false" +GOTRUE_EXTERNAL_AZURE_CLIENT_ID="" +GOTRUE_EXTERNAL_AZURE_SECRET="" +GOTRUE_EXTERNAL_AZURE_REDIRECT_URI="https://localhost:9999/callback" + +# Bitbucket OAuth config +GOTRUE_EXTERNAL_BITBUCKET_ENABLED="false" +GOTRUE_EXTERNAL_BITBUCKET_CLIENT_ID="" +GOTRUE_EXTERNAL_BITBUCKET_SECRET="" +GOTRUE_EXTERNAL_BITBUCKET_REDIRECT_URI="http://localhost:9999/callback" + +# Discord OAuth config +GOTRUE_EXTERNAL_DISCORD_ENABLED="false" +GOTRUE_EXTERNAL_DISCORD_CLIENT_ID="" +GOTRUE_EXTERNAL_DISCORD_SECRET="" +GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI="https://localhost:9999/callback" + +# Facebook OAuth config +GOTRUE_EXTERNAL_FACEBOOK_ENABLED="false" +GOTRUE_EXTERNAL_FACEBOOK_CLIENT_ID="" +GOTRUE_EXTERNAL_FACEBOOK_SECRET="" +GOTRUE_EXTERNAL_FACEBOOK_REDIRECT_URI="https://localhost:9999/callback" + +# Figma OAuth config +GOTRUE_EXTERNAL_FIGMA_ENABLED="false" +GOTRUE_EXTERNAL_FIGMA_CLIENT_ID="" +GOTRUE_EXTERNAL_FIGMA_SECRET="" +GOTRUE_EXTERNAL_FIGMA_REDIRECT_URI="https://localhost:9999/callback" + +# Gitlab OAuth config +GOTRUE_EXTERNAL_GITLAB_ENABLED="false" +GOTRUE_EXTERNAL_GITLAB_CLIENT_ID="" +GOTRUE_EXTERNAL_GITLAB_SECRET="" +GOTRUE_EXTERNAL_GITLAB_REDIRECT_URI="http://localhost:9999/callback" + +# Google OAuth config +GOTRUE_EXTERNAL_GOOGLE_ENABLED="false" +GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID="" +GOTRUE_EXTERNAL_GOOGLE_SECRET="" +GOTRUE_EXTERNAL_GOOGLE_REDIRECT_URI="http://localhost:9999/callback" + +# Github OAuth config +GOTRUE_EXTERNAL_GITHUB_ENABLED="false" +GOTRUE_EXTERNAL_GITHUB_CLIENT_ID="" +GOTRUE_EXTERNAL_GITHUB_SECRET="" +GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI="http://localhost:9999/callback" + +# Kakao OAuth config +GOTRUE_EXTERNAL_KAKAO_ENABLED="false" +GOTRUE_EXTERNAL_KAKAO_CLIENT_ID="" +GOTRUE_EXTERNAL_KAKAO_SECRET="" +GOTRUE_EXTERNAL_KAKAO_REDIRECT_URI="http://localhost:9999/callback" + +# Notion OAuth config +GOTRUE_EXTERNAL_NOTION_ENABLED="false" +GOTRUE_EXTERNAL_NOTION_CLIENT_ID="" +GOTRUE_EXTERNAL_NOTION_SECRET="" +GOTRUE_EXTERNAL_NOTION_REDIRECT_URI="https://localhost:9999/callback" + +# Twitter OAuth1 config +GOTRUE_EXTERNAL_TWITTER_ENABLED="false" +GOTRUE_EXTERNAL_TWITTER_CLIENT_ID="" +GOTRUE_EXTERNAL_TWITTER_SECRET="" +GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI="http://localhost:9999/callback" + +# Twitch OAuth config +GOTRUE_EXTERNAL_TWITCH_ENABLED="false" +GOTRUE_EXTERNAL_TWITCH_CLIENT_ID="" +GOTRUE_EXTERNAL_TWITCH_SECRET="" +GOTRUE_EXTERNAL_TWITCH_REDIRECT_URI="http://localhost:9999/callback" + +# Spotify OAuth config +GOTRUE_EXTERNAL_SPOTIFY_ENABLED="false" +GOTRUE_EXTERNAL_SPOTIFY_CLIENT_ID="" +GOTRUE_EXTERNAL_SPOTIFY_SECRET="" +GOTRUE_EXTERNAL_SPOTIFY_REDIRECT_URI="http://localhost:9999/callback" + +# Keycloak OAuth config +GOTRUE_EXTERNAL_KEYCLOAK_ENABLED="false" +GOTRUE_EXTERNAL_KEYCLOAK_CLIENT_ID="" +GOTRUE_EXTERNAL_KEYCLOAK_SECRET="" +GOTRUE_EXTERNAL_KEYCLOAK_REDIRECT_URI="http://localhost:9999/callback" +GOTRUE_EXTERNAL_KEYCLOAK_URL="https://keycloak.example.com/auth/realms/myrealm" + +# Linkedin OAuth config +GOTRUE_EXTERNAL_LINKEDIN_ENABLED="true" +GOTRUE_EXTERNAL_LINKEDIN_CLIENT_ID="" +GOTRUE_EXTERNAL_LINKEDIN_SECRET="" + +# Slack OAuth config +GOTRUE_EXTERNAL_SLACK_ENABLED="false" +GOTRUE_EXTERNAL_SLACK_CLIENT_ID="" +GOTRUE_EXTERNAL_SLACK_SECRET="" +GOTRUE_EXTERNAL_SLACK_REDIRECT_URI="http://localhost:9999/callback" + +# WorkOS OAuth config +GOTRUE_EXTERNAL_WORKOS_ENABLED="true" +GOTRUE_EXTERNAL_WORKOS_CLIENT_ID="" +GOTRUE_EXTERNAL_WORKOS_SECRET="" +GOTRUE_EXTERNAL_WORKOS_REDIRECT_URI="http://localhost:9999/callback" + +# Zoom OAuth config +GOTRUE_EXTERNAL_ZOOM_ENABLED="false" +GOTRUE_EXTERNAL_ZOOM_CLIENT_ID="" +GOTRUE_EXTERNAL_ZOOM_SECRET="" +GOTRUE_EXTERNAL_ZOOM_REDIRECT_URI="http://localhost:9999/callback" + +# Sign in with Solana +GOTRUE_EXTERNAL_WEB3_SOLANA_ENABLED="true" +GOTRUE_EXTERNAL_WEB3_SOLANA_MAXIMUM_VALIDITY_DURATION="10m" + +# Anonymous auth config +GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED="false" + +# PKCE Config +GOTRUE_EXTERNAL_FLOW_STATE_EXPIRY_DURATION="300s" + +# Phone provider config +GOTRUE_SMS_AUTOCONFIRM="false" +GOTRUE_SMS_MAX_FREQUENCY="5s" +GOTRUE_SMS_OTP_EXP="6000" +GOTRUE_SMS_OTP_LENGTH="6" +GOTRUE_SMS_PROVIDER="twilio" +GOTRUE_SMS_TWILIO_ACCOUNT_SID="" +GOTRUE_SMS_TWILIO_AUTH_TOKEN="" +GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID="" +GOTRUE_SMS_TEMPLATE="This is from supabase. Your code is {{ .Code }} ." +GOTRUE_SMS_MESSAGEBIRD_ACCESS_KEY="" +GOTRUE_SMS_MESSAGEBIRD_ORIGINATOR="" +GOTRUE_SMS_TEXTLOCAL_API_KEY="" +GOTRUE_SMS_TEXTLOCAL_SENDER="" +GOTRUE_SMS_VONAGE_API_KEY="" +GOTRUE_SMS_VONAGE_API_SECRET="" +GOTRUE_SMS_VONAGE_FROM="" + +# Captcha config +GOTRUE_SECURITY_CAPTCHA_ENABLED="false" +GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha" +GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000" +GOTRUE_SECURITY_CAPTCHA_TIMEOUT="10s" +GOTRUE_SESSION_KEY="" + +# SAML config +GOTRUE_EXTERNAL_SAML_ENABLED="true" +GOTRUE_EXTERNAL_SAML_METADATA_URL="" +GOTRUE_EXTERNAL_SAML_API_BASE="http://localhost:9999" +GOTRUE_EXTERNAL_SAML_NAME="auth0" +GOTRUE_EXTERNAL_SAML_SIGNING_CERT="" +GOTRUE_EXTERNAL_SAML_SIGNING_KEY="" + +# Additional Security config +GOTRUE_LOG_LEVEL="debug" +GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED="false" +GOTRUE_SECURITY_REFRESH_TOKEN_REUSE_INTERVAL="0" +GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION="false" +GOTRUE_OPERATOR_TOKEN="unused-operator-token" +GOTRUE_RATE_LIMIT_HEADER="X-Forwarded-For" +GOTRUE_RATE_LIMIT_EMAIL_SENT="100" + +GOTRUE_MAX_VERIFIED_FACTORS=10 + +# Auth Hook Configuration +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_ENABLED=false +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_URI="" +# Only for HTTPS Hooks +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_SECRET="" + +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_ENABLED=false +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_URI="" +# Only for HTTPS Hooks +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_SECRET="" + + +# Test OTP Config +GOTRUE_SMS_TEST_OTP=":, :..." +GOTRUE_SMS_TEST_OTP_VALID_UNTIL="2050-01-01T01:00:00Z" # (e.g. 2023-09-29T08:14:06Z) diff --git a/www/index.html b/internal/reloader/testdata/backups/empty.backup similarity index 100% rename from www/index.html rename to internal/reloader/testdata/backups/empty.backup diff --git a/internal/reloader/testdata/empty.env.example b/internal/reloader/testdata/empty.env.example new file mode 100644 index 000000000..e69de29bb diff --git a/internal/security/captcha.go b/internal/security/captcha.go new file mode 100644 index 000000000..aeacb6338 --- /dev/null +++ b/internal/security/captcha.go @@ -0,0 +1,101 @@ +package security + +import ( + "encoding/json" + "log" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "time" + + "fmt" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/utilities" +) + +type GotrueRequest struct { + Security GotrueSecurity `json:"gotrue_meta_security"` +} + +type GotrueSecurity struct { + Token string `json:"captcha_token"` +} + +type VerificationResponse struct { + Success bool `json:"success"` + ErrorCodes []string `json:"error-codes"` + Hostname string `json:"hostname"` +} + +var Client *http.Client + +func init() { + var defaultTimeout time.Duration = time.Second * 10 + timeoutStr := os.Getenv("GOTRUE_SECURITY_CAPTCHA_TIMEOUT") + if timeoutStr != "" { + if timeout, err := time.ParseDuration(timeoutStr); err != nil { + log.Fatalf("error loading GOTRUE_SECURITY_CAPTCHA_TIMEOUT: %v", err.Error()) + } else if timeout != 0 { + defaultTimeout = timeout + } + } + + Client = &http.Client{Timeout: defaultTimeout} +} + +func VerifyRequest(requestBody *GotrueRequest, clientIP, secretKey, captchaProvider string) (VerificationResponse, error) { + captchaResponse := strings.TrimSpace(requestBody.Security.Token) + + if captchaResponse == "" { + return VerificationResponse{}, errors.New("no captcha response (captcha_token) found in request") + } + + captchaURL, err := GetCaptchaURL(captchaProvider) + if err != nil { + return VerificationResponse{}, err + } + + return verifyCaptchaCode(captchaResponse, secretKey, clientIP, captchaURL) +} + +func verifyCaptchaCode(token, secretKey, clientIP, captchaURL string) (VerificationResponse, error) { + data := url.Values{} + data.Set("secret", secretKey) + data.Set("response", token) + data.Set("remoteip", clientIP) + // TODO (darora): pipe through sitekey + + r, err := http.NewRequest("POST", captchaURL, strings.NewReader(data.Encode())) + if err != nil { + return VerificationResponse{}, errors.Wrap(err, "couldn't initialize request object for captcha check") + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) + res, err := Client.Do(r) + if err != nil { + return VerificationResponse{}, errors.Wrap(err, "failed to verify captcha response") + } + defer utilities.SafeClose(res.Body) + + var verificationResponse VerificationResponse + + if err := json.NewDecoder(res.Body).Decode(&verificationResponse); err != nil { + return VerificationResponse{}, errors.Wrap(err, "failed to decode captcha response: not JSON") + } + + return verificationResponse, nil +} + +func GetCaptchaURL(captchaProvider string) (string, error) { + switch captchaProvider { + case "hcaptcha": + return "https://hcaptcha.com/siteverify", nil + case "turnstile": + return "https://challenges.cloudflare.com/turnstile/v0/siteverify", nil + default: + return "", fmt.Errorf("captcha Provider %q could not be found", captchaProvider) + } +} diff --git a/internal/storage/dial.go b/internal/storage/dial.go new file mode 100644 index 000000000..3ee99395b --- /dev/null +++ b/internal/storage/dial.go @@ -0,0 +1,192 @@ +package storage + +import ( + "context" + "database/sql" + "net/url" + "reflect" + "time" + + "github.com/XSAM/otelsql" + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/pop/v6/columns" + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" +) + +// Connection is the interface a storage provider must implement. +type Connection struct { + *pop.Connection +} + +// Dial will connect to that storage engine +func Dial(config *conf.GlobalConfiguration) (*Connection, error) { + if config.DB.Driver == "" && config.DB.URL != "" { + u, err := url.Parse(config.DB.URL) + if err != nil { + return nil, errors.Wrap(err, "parsing db connection url") + } + config.DB.Driver = u.Scheme + } + + driver := "" + if config.DB.Driver != "postgres" { + logrus.Warn("DEPRECATION NOTICE: only PostgreSQL is supported by Supabase's GoTrue, will be removed soon") + } else { + // pop v5 uses pgx as the default PostgreSQL driver + driver = "pgx" + } + + if driver != "" && (config.Tracing.Enabled || config.Metrics.Enabled) { + instrumentedDriver, err := otelsql.Register(driver) + if err != nil { + logrus.WithError(err).Errorf("unable to instrument sql driver %q for use with OpenTelemetry", driver) + } else { + logrus.Debugf("using %s as an instrumented driver for OpenTelemetry", instrumentedDriver) + + // sqlx needs to be informed that the new instrumented + // driver has the same semantics as the + // non-instrumented driver + sqlx.BindDriver(instrumentedDriver, sqlx.BindType(driver)) + + driver = instrumentedDriver + } + } + + options := make(map[string]string) + + if config.DB.HealthCheckPeriod != time.Duration(0) { + options["pool_health_check_period"] = config.DB.HealthCheckPeriod.String() + } + + if config.DB.ConnMaxIdleTime != time.Duration(0) { + options["pool_max_conn_idle_time"] = config.DB.ConnMaxIdleTime.String() + } + + db, err := pop.NewConnection(&pop.ConnectionDetails{ + Dialect: config.DB.Driver, + Driver: driver, + URL: config.DB.URL, + Pool: config.DB.MaxPoolSize, + IdlePool: config.DB.MaxIdlePoolSize, + ConnMaxLifetime: config.DB.ConnMaxLifetime, + ConnMaxIdleTime: config.DB.ConnMaxIdleTime, + Options: options, + }) + if err != nil { + return nil, errors.Wrap(err, "opening database connection") + } + if err := db.Open(); err != nil { + return nil, errors.Wrap(err, "checking database connection") + } + + if config.Metrics.Enabled { + registerOpenTelemetryDatabaseStats(db) + } + + return &Connection{db}, nil +} + +func registerOpenTelemetryDatabaseStats(db *pop.Connection) { + defer func() { + if rec := recover(); rec != nil { + logrus.WithField("error", rec).Error("registerOpenTelemetryDatabaseStats is not able to determine database object with reflection -- panicked") + } + }() + + dbval := reflect.Indirect(reflect.ValueOf(db.Store)) + dbfield := dbval.Field(0) + sqldbfield := reflect.Indirect(dbfield).Field(0) + + sqldb, ok := sqldbfield.Interface().(*sql.DB) + if !ok || sqldb == nil { + logrus.Error("registerOpenTelemetryDatabaseStats is not able to determine database object with reflection") + return + } + + if err := otelsql.RegisterDBStatsMetrics(sqldb); err != nil { + logrus.WithError(err).Error("unable to register OpenTelemetry stats metrics for databse") + } else { + logrus.Debug("registered OpenTelemetry stats metrics for database") + } +} + +type CommitWithError struct { + Err error +} + +func (e *CommitWithError) Error() string { + return e.Err.Error() +} + +func (e *CommitWithError) Cause() error { + return e.Err +} + +// NewCommitWithError creates an error that can be returned in a pop transaction +// without rolling back the transaction. This should only be used in cases where +// you want the transaction to commit but return an error message to the user. +func NewCommitWithError(err error) *CommitWithError { + return &CommitWithError{Err: err} +} + +func (c *Connection) Transaction(fn func(*Connection) error) error { + if c.TX == nil { + var returnErr error + if terr := c.Connection.Transaction(func(tx *pop.Connection) error { + err := fn(&Connection{tx}) + switch err.(type) { + case *CommitWithError: + returnErr = err + return nil + default: + return err + } + }); terr != nil { + // there exists a race condition when the context deadline is exceeded + // and whether the transaction has been committed or not + // e.g. if the context deadline has exceeded but the transaction has already been committed, + // it won't be possible to perform a rollback on the transaction since the transaction has been closed + if !errors.Is(terr, sql.ErrTxDone) { + return terr + } + } + return returnErr + } + return fn(c) +} + +// WithContext returns a new connection with an updated context. This is +// typically used for tracing as the context contains trace span information. +func (c *Connection) WithContext(ctx context.Context) *Connection { + return &Connection{c.Connection.WithContext(ctx)} +} + +func getExcludedColumns(model interface{}, includeColumns ...string) ([]string, error) { + sm := &pop.Model{Value: model} + st := reflect.TypeOf(model) + if st.Kind() == reflect.Ptr { + _ = st.Elem() + } + + // get all columns and remove included to get excluded set + cols := columns.ForStructWithAlias(model, sm.TableName(), sm.As, sm.IDField()) + for _, f := range includeColumns { + if _, ok := cols.Cols[f]; !ok { + return nil, errors.Errorf("Invalid column name %s", f) + } + cols.Remove(f) + } + + xcols := make([]string, 0, len(cols.Cols)) + for n := range cols.Cols { + // gobuffalo updates the updated_at column automatically + if n == "updated_at" { + continue + } + xcols = append(xcols, n) + } + return xcols, nil +} diff --git a/internal/storage/dial_test.go b/internal/storage/dial_test.go new file mode 100644 index 000000000..078b6d57a --- /dev/null +++ b/internal/storage/dial_test.go @@ -0,0 +1,60 @@ +package storage + +import ( + "errors" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +type TestUser struct { + ID uuid.UUID + Role string `db:"role"` + Other string `db:"othercol"` +} + +func TestGetExcludedColumns(t *testing.T) { + u := TestUser{} + cols, err := getExcludedColumns(u, "role") + require.NoError(t, err) + require.NotContains(t, cols, "role") + require.Contains(t, cols, "othercol") +} + +func TestGetExcludedColumns_InvalidName(t *testing.T) { + u := TestUser{} + _, err := getExcludedColumns(u, "adsf") + require.Error(t, err) +} + +func TestTransaction(t *testing.T) { + apiTestConfig := "../../hack/test.env" + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + conn, err := Dial(config) + require.NoError(t, err) + require.NotNil(t, conn) + + defer func() { + // clean up the test table created + require.NoError(t, conn.RawQuery("drop table if exists test").Exec(), "Error removing table") + }() + + commitWithError := NewCommitWithError(errors.New("commit with error")) + err = conn.Transaction(func(tx *Connection) error { + require.NoError(t, tx.RawQuery("create table if not exists test()").Exec(), "Error saving creating test table") + return commitWithError + }) + require.Error(t, err) + require.ErrorIs(t, err, commitWithError) + + type TestData struct{} + + // check that transaction is still being committed despite returning an error above + data := []TestData{} + err = conn.RawQuery("select * from test").All(&data) + require.NoError(t, err) + require.Empty(t, data) +} diff --git a/storage/helper.go b/internal/storage/helper.go similarity index 79% rename from storage/helper.go rename to internal/storage/helper.go index 2df3a1675..23599840a 100644 --- a/storage/helper.go +++ b/internal/storage/helper.go @@ -14,7 +14,7 @@ func (s *NullString) Scan(value interface{}) error { } strVal, ok := value.(string) if !ok { - return errors.New("Column is not a string") + return errors.New("column is not a string") } *s = NullString(strVal) return nil @@ -25,3 +25,7 @@ func (s NullString) Value() (driver.Value, error) { } return string(s), nil } + +func (s NullString) String() string { + return string(s) +} diff --git a/storage/sql.go b/internal/storage/sql.go similarity index 100% rename from storage/sql.go rename to internal/storage/sql.go diff --git a/storage/test/db_setup.go b/internal/storage/test/db_setup.go similarity index 64% rename from storage/test/db_setup.go rename to internal/storage/test/db_setup.go index 817e46274..8eeb09998 100644 --- a/storage/test/db_setup.go +++ b/internal/storage/test/db_setup.go @@ -1,8 +1,8 @@ package test import ( - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/storage" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" ) func SetupDBConnection(globalConfig *conf.GlobalConfiguration) (*storage.Connection, error) { diff --git a/internal/utilities/context.go b/internal/utilities/context.go new file mode 100644 index 000000000..06aa74a39 --- /dev/null +++ b/internal/utilities/context.go @@ -0,0 +1,51 @@ +package utilities + +import ( + "context" + "sync" +) + +type contextKey string + +func (c contextKey) String() string { + return "gotrue api context key " + string(c) +} + +const ( + requestIDKey = contextKey("request_id") +) + +// WithRequestID adds the provided request ID to the context. +func WithRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, requestIDKey, id) +} + +// GetRequestID reads the request ID from the context. +func GetRequestID(ctx context.Context) string { + obj := ctx.Value(requestIDKey) + if obj == nil { + return "" + } + + return obj.(string) +} + +// WaitForCleanup waits until all long-running goroutines shut +// down cleanly or until the provided context signals done. +func WaitForCleanup(ctx context.Context, wg *sync.WaitGroup) { + cleanupDone := make(chan struct{}) + + go func() { + defer close(cleanupDone) + + wg.Wait() + }() + + select { + case <-ctx.Done(): + return + + case <-cleanupDone: + return + } +} diff --git a/internal/utilities/hibpcache.go b/internal/utilities/hibpcache.go new file mode 100644 index 000000000..14c3fc3ad --- /dev/null +++ b/internal/utilities/hibpcache.go @@ -0,0 +1,76 @@ +package utilities + +import ( + "context" + "sync" + + "github.com/bits-and-blooms/bloom/v3" +) + +const ( + // hibpHashLength is the length of a hex-encoded SHA1 hash. + hibpHashLength = 40 + // hibpHashPrefixLength is the length of the hashed password prefix. + hibpHashPrefixLength = 5 +) + +type HIBPBloomCache struct { + sync.RWMutex + + n uint + items uint + filter *bloom.BloomFilter +} + +func NewHIBPBloomCache(n uint, fp float64) *HIBPBloomCache { + cache := &HIBPBloomCache{ + n: n, + filter: bloom.NewWithEstimates(n, fp), + } + + return cache +} + +func (c *HIBPBloomCache) Cap() uint { + return c.filter.Cap() +} + +func (c *HIBPBloomCache) Add(ctx context.Context, prefix []byte, suffixes [][]byte) error { + c.Lock() + defer c.Unlock() + + c.items += uint(len(suffixes)) + + if c.items > (4*c.n)/5 { + // clear the filter if 80% full to keep the actual false + // positive rate low + c.filter.ClearAll() + + // reduce memory footprint when this happens + c.filter.BitSet().Compact() + + c.items = uint(len(suffixes)) + } + + var combined [hibpHashLength]byte + copy(combined[:], prefix) + + for _, suffix := range suffixes { + copy(combined[hibpHashPrefixLength:], suffix) + + c.filter.Add(combined[:]) + } + + return nil +} + +func (c *HIBPBloomCache) Contains(ctx context.Context, prefix, suffix []byte) (bool, error) { + var combined [hibpHashLength]byte + copy(combined[:], prefix) + copy(combined[hibpHashPrefixLength:], suffix) + + c.RLock() + defer c.RUnlock() + + return c.filter.Test(combined[:]), nil +} diff --git a/internal/utilities/io.go b/internal/utilities/io.go new file mode 100644 index 000000000..ab89b4c34 --- /dev/null +++ b/internal/utilities/io.go @@ -0,0 +1,13 @@ +package utilities + +import ( + "io" + + "github.com/sirupsen/logrus" +) + +func SafeClose(closer io.Closer) { + if err := closer.Close(); err != nil { + logrus.WithError(err).Warn("Close operation failed") + } +} diff --git a/internal/utilities/postgres.go b/internal/utilities/postgres.go new file mode 100644 index 000000000..4d7fde8d0 --- /dev/null +++ b/internal/utilities/postgres.go @@ -0,0 +1,76 @@ +package utilities + +import ( + "errors" + "strconv" + "strings" + + "github.com/jackc/pgconn" + "github.com/jackc/pgerrcode" +) + +// PostgresError is a custom error struct for marshalling Postgres errors to JSON. +type PostgresError struct { + Code string `json:"code"` + HttpStatusCode int `json:"-"` + Message string `json:"message"` + Hint string `json:"hint,omitempty"` + Detail string `json:"detail,omitempty"` +} + +// NewPostgresError returns a new PostgresError if the error was from a publicly +// accessible Postgres error. +func NewPostgresError(err error) *PostgresError { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && isPubliclyAccessiblePostgresError(pgErr.Code) { + return &PostgresError{ + Code: pgErr.Code, + HttpStatusCode: getHttpStatusCodeFromPostgresErrorCode(pgErr.Code), + Message: pgErr.Message, + Detail: pgErr.Detail, + Hint: pgErr.Hint, + } + } + + return nil +} +func (pg *PostgresError) IsUniqueConstraintViolated() bool { + // See https://www.postgresql.org/docs/current/errcodes-appendix.html for list of error codes + return pg.Code == "23505" +} + +// isPubliclyAccessiblePostgresError checks if the Postgres error should be +// made accessible. +func isPubliclyAccessiblePostgresError(code string) bool { + if len(code) != 5 { + return false + } + + // default response + return getHttpStatusCodeFromPostgresErrorCode(code) != 0 +} + +// getHttpStatusCodeFromPostgresErrorCode maps a Postgres error code to a HTTP +// status code. Returns 0 if the code doesn't map to a given postgres error code. +func getHttpStatusCodeFromPostgresErrorCode(code string) int { + if code == pgerrcode.RaiseException || + code == pgerrcode.IntegrityConstraintViolation || + code == pgerrcode.RestrictViolation || + code == pgerrcode.NotNullViolation || + code == pgerrcode.ForeignKeyViolation || + code == pgerrcode.UniqueViolation || + code == pgerrcode.CheckViolation || + code == pgerrcode.ExclusionViolation { + return 500 + } + + // Use custom HTTP status code if Postgres error was triggered with `PTXXX` + // code. This is consistent with PostgREST's behaviour as well. + if strings.HasPrefix(code, "PT") { + if httpStatusCode, err := strconv.ParseInt(code[2:], 10, 0); err == nil { + return int(httpStatusCode) + } + } + + return 0 +} diff --git a/internal/utilities/request.go b/internal/utilities/request.go new file mode 100644 index 000000000..bfe0a51df --- /dev/null +++ b/internal/utilities/request.go @@ -0,0 +1,121 @@ +package utilities + +import ( + "bytes" + "io" + "net" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" +) + +// GetIPAddress returns the real IP address of the HTTP request. It parses the +// X-Forwarded-For header. +func GetIPAddress(r *http.Request) string { + if r.Header != nil { + xForwardedFor := r.Header.Get("X-Forwarded-For") + if xForwardedFor != "" { + ips := strings.Split(xForwardedFor, ",") + for i := range ips { + ips[i] = strings.TrimSpace(ips[i]) + } + + for _, ip := range ips { + if ip != "" { + parsed := net.ParseIP(ip) + if parsed == nil { + continue + } + + return parsed.String() + } + } + } + } + + ipPort := r.RemoteAddr + ip, _, err := net.SplitHostPort(ipPort) + if err != nil { + return ipPort + } + + return ip +} + +// GetBodyBytes reads the whole request body properly into a byte array. +func GetBodyBytes(req *http.Request) ([]byte, error) { + if req.Body == nil || req.Body == http.NoBody { + return nil, nil + } + + originalBody := req.Body + defer SafeClose(originalBody) + + buf, err := io.ReadAll(originalBody) + if err != nil { + return nil, err + } + + req.Body = io.NopCloser(bytes.NewReader(buf)) + + return buf, nil +} + +func GetReferrer(r *http.Request, config *conf.GlobalConfiguration) string { + // try get redirect url from query or post data first + reqref := getRedirectTo(r) + if IsRedirectURLValid(config, reqref) { + return reqref + } + + // instead try referrer header value + reqref = r.Referer() + if IsRedirectURLValid(config, reqref) { + return reqref + } + + return config.SiteURL +} + +func IsRedirectURLValid(config *conf.GlobalConfiguration, redirectURL string) bool { + if redirectURL == "" { + return false + } + + base, berr := url.Parse(config.SiteURL) + refurl, rerr := url.Parse(redirectURL) + + // As long as the referrer came from the site, we will redirect back there + if berr == nil && rerr == nil && base.Hostname() == refurl.Hostname() { + return true + } + + // Clean up the referrer URL to avoid pattern matching an invalid URL + refurl.Fragment = "" + refurl.RawQuery = "" + + // For case when user came from mobile app or other permitted resource - redirect back + for _, pattern := range config.URIAllowListMap { + if pattern.Match(refurl.String()) { + return true + } + } + + return false +} + +// getRedirectTo tries extract redirect url from header or from query params +func getRedirectTo(r *http.Request) (reqref string) { + reqref = r.Header.Get("redirect_to") + if reqref != "" { + return + } + + if err := r.ParseForm(); err == nil { + reqref = r.Form.Get("redirect_to") + } + + return +} diff --git a/internal/utilities/request_test.go b/internal/utilities/request_test.go new file mode 100644 index 000000000..d08e86ae6 --- /dev/null +++ b/internal/utilities/request_test.go @@ -0,0 +1,144 @@ +package utilities + +import ( + "net/http" + "net/http/httptest" + tst "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestGetIPAddress(t *tst.T) { + examples := []func(r *http.Request) string{ + func(r *http.Request) string { + r.Header = nil + r.RemoteAddr = "127.0.0.1:8080" + + return "127.0.0.1" + }, + + func(r *http.Request) string { + r.Header = nil + r.RemoteAddr = "incorrect" + + return "incorrect" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + + return "127.0.0.1" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "[::1]:8080" + + return "::1" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "127.0.0.2") + + return "127.0.0.2" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "127.0.0.2") + + return "127.0.0.2" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "127.0.0.2,") + + return "127.0.0.2" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "127.0.0.2,127.0.0.3") + + return "127.0.0.2" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "::1,127.0.0.2") + + return "::1" + }, + } + + for _, example := range examples { + req := &http.Request{} + expected := example(req) + + require.Equal(t, GetIPAddress(req), expected) + } +} + +func TestGetReferrer(t *tst.T) { + config := conf.GlobalConfiguration{ + SiteURL: "https://example.com", + URIAllowList: []string{"http://localhost:8000/*", "http://*.localhost:8000/*"}, + JWT: conf.JWTConfiguration{ + Secret: "testsecret", + }, + } + require.NoError(t, config.ApplyDefaults()) + cases := []struct { + desc string + redirectURL string + expected string + }{ + { + desc: "valid redirect url", + redirectURL: "http://localhost:8000/path", + expected: "http://localhost:8000/path", + }, + { + desc: "invalid redirect url", + redirectURL: "http://localhost:3000", + expected: config.SiteURL, + }, + { + desc: "no / separator", + redirectURL: "http://localhost:8000", + expected: config.SiteURL, + }, + { + desc: "* respects separator", + redirectURL: "http://localhost:8000/path/to/page", + expected: config.SiteURL, + }, + { + desc: "* respects parameters", + redirectURL: "http://localhost:8000/path?param=1", + expected: "http://localhost:8000/path?param=1", + }, + { + desc: "invalid redirect url via query smurfing", + redirectURL: "http://123?.localhost:8000/path", + expected: config.SiteURL, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *tst.T) { + r := httptest.NewRequest("GET", "http://localhost?redirect_to="+c.redirectURL, nil) + referrer := GetReferrer(r, &config) + require.Equal(t, c.expected, referrer) + }) + } +} diff --git a/internal/utilities/siws/helpers.go b/internal/utilities/siws/helpers.go new file mode 100644 index 000000000..65bcca0df --- /dev/null +++ b/internal/utilities/siws/helpers.go @@ -0,0 +1,17 @@ +package siws + +import ( + "regexp" +) + +var domainPattern = regexp.MustCompile(`^(localhost|(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,})(?::\d{1,5})?$`) + +func IsValidDomain(domain string) bool { + return domainPattern.MatchString(domain) +} + +var validSolanaNetworksPattern = regexp.MustCompile("^solana:(main|dev|test|local)net$") + +func IsValidSolanaNetwork(network string) bool { + return validSolanaNetworksPattern.MatchString(network) +} diff --git a/internal/utilities/siws/helpers_test.go b/internal/utilities/siws/helpers_test.go new file mode 100644 index 000000000..57bb5fdc7 --- /dev/null +++ b/internal/utilities/siws/helpers_test.go @@ -0,0 +1 @@ +package siws diff --git a/internal/utilities/siws/parser.go b/internal/utilities/siws/parser.go new file mode 100644 index 000000000..dcb3fb36e --- /dev/null +++ b/internal/utilities/siws/parser.go @@ -0,0 +1,199 @@ +package siws + +import ( + "crypto/ed25519" + "errors" + "fmt" + "net/url" + "regexp" + "strings" + "time" + + "github.com/btcsuite/btcutil/base58" +) + +// SIWSMessage is the final structured form of a parsed SIWS message. +type SIWSMessage struct { + Raw string + + Domain string + Address string + Statement string + URI *url.URL + Version string + Nonce string + IssuedAt time.Time + ChainID string + NotBefore time.Time + RequestID string + ExpirationTime time.Time + Resources []*url.URL +} + +const headerSuffix = " wants you to sign in with your Solana account:" + +var addressPattern = regexp.MustCompile("^[a-zA-Z0-9]{32,44}$") + +func ParseMessage(raw string) (*SIWSMessage, error) { + lines := strings.Split(raw, "\n") + if len(lines) < 6 { + return nil, errors.New("siws: message needs at least 6 lines") + } + + // Parse first line exactly + header := lines[0] + if !strings.HasSuffix(header, headerSuffix) { + return nil, fmt.Errorf("siws: message first line does not end in %q", headerSuffix) + } + + domain := strings.TrimSpace(strings.TrimSuffix(header, headerSuffix)) + if !IsValidDomain(domain) { + return nil, errors.New("siws: domain in first line of message is not valid") + } + + address := strings.TrimSpace(lines[1]) + if !addressPattern.MatchString(address) { + return nil, errors.New("siws: wallet address is not in base58 format") + } + + msg := &SIWSMessage{ + Raw: raw, + Domain: domain, + Address: address, + } + + if lines[2] != "" { + return nil, errors.New("siws: third line must be empty") + } + + startIndex := 3 + if lines[3] != "" && lines[4] == "" { + msg.Statement = lines[3] + startIndex = 5 + } + + inResources := false + for i := startIndex; i < len(lines); i += 1 { + line := strings.TrimSpace(lines[i]) + + if inResources { + if strings.HasPrefix(line, "- ") { + resource := strings.TrimSpace(strings.TrimPrefix(line, "- ")) + + resourceURL, err := url.ParseRequestURI(resource) + if err != nil { + return nil, fmt.Errorf("siws: Resource at position %d has invalid URI", len(msg.Resources)) + } + + msg.Resources = append(msg.Resources, resourceURL) + continue + } else { + inResources = false + } + } + + if line == "Resources:" { + inResources = true + continue + } + + if line == "" { + continue + } + + key, value, found := strings.Cut(line, ":") + if !found { + return nil, fmt.Errorf("siws: encountered unparsable line at index %d", i) + } + + value = strings.TrimSpace(value) + + switch key { + case "URI": + uri, err := url.ParseRequestURI(value) + if err != nil { + return nil, errors.New("siws: URI is not valid") + } + + msg.URI = uri + + case "Version": + msg.Version = value + + case "Chain ID": + msg.ChainID = value + + case "Nonce": + msg.Nonce = value + + case "Issued At": + ts, err := time.Parse(time.RFC3339, value) + if err != nil { + ts, err = time.Parse(time.RFC3339Nano, value) + if err != nil { + return nil, errors.New("siws: Issued At is not a valid ISO8601 timestamp") + } + } + msg.IssuedAt = ts + + case "Expiration Time": + ts, err := time.Parse(time.RFC3339, value) + if err != nil { + ts, err = time.Parse(time.RFC3339Nano, value) + if err != nil { + return nil, errors.New("siws: Expiration Time is not a valid ISO8601 timestamp") + } + } + msg.ExpirationTime = ts + + case "Not Before": + ts, err := time.Parse(time.RFC3339, value) + if err != nil { + ts, err = time.Parse(time.RFC3339Nano, value) + if err != nil { + return nil, errors.New("siws: Not Before is not a valid ISO8601 timestamp") + } + } + msg.NotBefore = ts + + case "Request ID": + msg.RequestID = value + } + } + + if msg.Version != "1" { + return nil, fmt.Errorf("siws: Version value is not supported, expected 1 got %q", msg.Version) + } + + if msg.IssuedAt.IsZero() { + return nil, errors.New("siws: Issued At is not specified") + } + + if msg.URI == nil { + return nil, errors.New("siws: URI is not specified") + } + + if msg.ChainID != "" && !IsValidSolanaNetwork(msg.ChainID) { + return nil, errors.New("siws: Chain ID is not valid") + } + + if !msg.IssuedAt.IsZero() && !msg.ExpirationTime.IsZero() { + if msg.IssuedAt.After(msg.ExpirationTime) { + return nil, errors.New("siws: Issued At is after Expiration Time") + } + } + + if !msg.NotBefore.IsZero() && !msg.ExpirationTime.IsZero() { + if msg.NotBefore.After(msg.ExpirationTime) { + return nil, errors.New("siws: Not Before is after Expiration Time") + } + } + + return msg, nil +} + +func (m *SIWSMessage) VerifySignature(signature []byte) bool { + pubKey := base58.Decode(m.Address) + + return ed25519.Verify(pubKey, []byte(m.Raw), signature) +} diff --git a/internal/utilities/siws/parser_test.go b/internal/utilities/siws/parser_test.go new file mode 100644 index 000000000..f979fb426 --- /dev/null +++ b/internal/utilities/siws/parser_test.go @@ -0,0 +1,126 @@ +package siws + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseMessage(t *testing.T) { + negativeExamples := []struct { + example string + error string + }{ + { + example: "", + error: "message needs at least 6 lines", + }, + { + example: "\n\n\n\n", + error: "message needs at least 6 lines", + }, + { + example: "domain.com whatever\n\n\n\n\n\n", + error: "message first line does not end in \" wants you to sign in with your Solana account:\"", + }, + { + example: "******* wants you to sign in with your Solana account:\n\n\n\n\n\n", + error: "domain in first line of message is not valid", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n***************************************\n\n\n\n\n", + error: "wallet address is not in base58 format", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\nURI: https://google.com\n\n\n", + error: "third line must be empty", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nNot Parsable\n", + error: "encountered unparsable line at index 5", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: ***\nIssued At: 2025-01-01T00:00:00Z", + error: "URI is not valid", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: https://google.com\nIssued At: not-a-timestamp", + error: "Issued At is not a valid ISO8601 timestamp", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: https://google.com\nIssued At: 2025-01-01T00:00:00Z\nExpiration Time: not-a-timestamp", + error: "Expiration Time is not a valid ISO8601 timestamp", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: https://google.com\nIssued At: 2025-01-01T00:00:00Z\nNot Before: not-a-timestamp", + error: "Not Before is not a valid ISO8601 timestamp", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 2\nIssued At: 2025-01-01T00:00:00Z\nURI: https://google.com\n", + error: "Version value is not supported, expected 1 got \"2\"", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nIssued At: 2025-01-01T00:00:00Z\n\n", + error: "URI is not specified", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: https://domain.com\nResources:\n- https://google.com\n", + error: "Issued At is not specified", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: https://domain.com\nIssued At: 2025-01-02T00:00:00Z\nExpiration Time: 2025-01-01T00:00:00Z\n", + error: "Issued At is after Expiration Time", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: https://domain.com\nIssued At: 2025-01-01T00:00:00Z\nExpiration Time: 2025-01-02T00:00:00Z\nNot Before: 2025-01-03T00:00:00Z\n", + error: "Not Before is after Expiration Time", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: https://domain.com\nIssued At: 2025-01-01T00:00:00Z\nResources:\n- https://google.com\n- ***\n", + error: "Resource at position 1 has invalid URI", + }, + { + example: "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nVersion: 1\nURI: https://domain.com\nIssued At: 2025-01-01T00:00:00Z\nChain ID: random:mainnet", + error: "Chain ID is not valid", + }, + } + + for i, example := range negativeExamples { + _, err := ParseMessage(example.example) + + t.Run(fmt.Sprintf("negative example %d", i), func(t *testing.T) { + require.NotNil(t, err) + require.Equal(t, "siws: "+example.error, err.Error()) + }) + } + + positiveExamples := []string{ + "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nStatement\n\nVersion: 1\nURI: https://domain.com\nIssued At: 2025-01-01T00:00:00Z\nNonce: 123\nRequest ID: abcdef\nChain ID: solana:testnet", + "domain.com wants you to sign in with your Solana account:\n4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR\n\nVersion: 1\nURI: https://domain.com\nIssued At: 2025-01-01T00:00:00Z\nNonce: 123\nRequest ID: abcdef\nChain ID: solana:testnet", + } + + for i, example := range positiveExamples { + t.Run(fmt.Sprintf("positive example %d", i), func(t *testing.T) { + parsed, err := ParseMessage(example) + + require.Nil(t, err) + require.Equal(t, "domain.com", parsed.Domain) + require.Equal(t, "4Cw1koUQtqybLFem7uqhzMBznMPGARbFS4cjaYbM9RnR", parsed.Address) + + if i == 0 { + require.Equal(t, "Statement", parsed.Statement) + } else { + require.Equal(t, "", parsed.Statement) + } + + require.Equal(t, "2025-01-01 00:00:00 +0000 UTC", parsed.IssuedAt.String()) + require.Equal(t, "https://domain.com", parsed.URI.String()) + require.Equal(t, "solana:testnet", parsed.ChainID) + require.Equal(t, "123", parsed.Nonce) + require.Equal(t, "abcdef", parsed.RequestID) + + require.Equal(t, false, parsed.VerifySignature(make([]byte, 64))) + }) + } +} diff --git a/internal/utilities/version.go b/internal/utilities/version.go new file mode 100644 index 000000000..b3ba95a28 --- /dev/null +++ b/internal/utilities/version.go @@ -0,0 +1,4 @@ +package utilities + +// Version is git commit or release tag from which this binary was built. +var Version string diff --git a/mailer/mailer.go b/mailer/mailer.go deleted file mode 100644 index 496572760..000000000 --- a/mailer/mailer.go +++ /dev/null @@ -1,83 +0,0 @@ -package mailer - -import ( - "net/url" - "regexp" - - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/netlify/mailme" - "github.com/sirupsen/logrus" - "gopkg.in/gomail.v2" -) - -// Mailer defines the interface a mailer must implement. -type Mailer interface { - Send(user *models.User, subject, body string, data map[string]interface{}) error - InviteMail(user *models.User, referrerURL string) error - ConfirmationMail(user *models.User, referrerURL string) error - RecoveryMail(user *models.User, referrerURL string) error - MagicLinkMail(user *models.User, referrerURL string) error - EmailChangeMail(user *models.User, referrerURL string) error - ValidateEmail(email string) error - GetEmailActionLink(user *models.User, actionType, referrerURL string) (string, error) -} - -// NewMailer returns a new gotrue mailer -func NewMailer(instanceConfig *conf.Configuration) Mailer { - if instanceConfig.SMTP.Host == "" { - logrus.Infof("Noop mailer being used for %v", instanceConfig.SiteURL) - return &noopMailer{} - } - - mail := gomail.NewMessage() - from := mail.FormatAddress(instanceConfig.SMTP.AdminEmail, instanceConfig.SMTP.SenderName) - - return &TemplateMailer{ - SiteURL: instanceConfig.SiteURL, - Config: instanceConfig, - Mailer: &mailme.Mailer{ - Host: instanceConfig.SMTP.Host, - Port: instanceConfig.SMTP.Port, - User: instanceConfig.SMTP.User, - Pass: instanceConfig.SMTP.Pass, - From: from, - BaseURL: instanceConfig.SiteURL, - Logger: logrus.New(), - }, - } -} - -func withDefault(value, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} - -func getSiteURL(referrerURL, siteURL, filepath, fragment string) (string, error) { - baseURL := siteURL - if filepath == "" && referrerURL != "" { - baseURL = referrerURL - } - - site, err := url.Parse(baseURL) - if err != nil { - return "", err - } - if filepath != "" { - path, err := url.Parse(filepath) - if err != nil { - return "", err - } - site = site.ResolveReference(path) - } - site.RawQuery = fragment - return site.String(), nil -} - -var urlRegexp = regexp.MustCompile(`^https?://[^/]+`) - -func enforceRelativeURL(url string) string { - return urlRegexp.ReplaceAllString(url, "") -} diff --git a/mailer/mailer_test.go b/mailer/mailer_test.go deleted file mode 100644 index a8a74adec..000000000 --- a/mailer/mailer_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package mailer - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGetSiteURL(t *testing.T) { - cases := []struct { - ReferrerURL string - SiteURL string - Path string - Fragment string - Expected string - }{ - {"", "https://test.example.com", "/templates/confirm.html", "", "https://test.example.com/templates/confirm.html"}, - {"", "https://test.example.com/removedpath", "/templates/confirm.html", "", "https://test.example.com/templates/confirm.html"}, - {"", "https://test.example.com/", "/trailingslash/", "", "https://test.example.com/trailingslash/"}, - {"", "https://test.example.com", "f", "fragment", "https://test.example.com/f?fragment"}, - {"https://test.example.com/admin", "https://test.example.com", "", "fragment", "https://test.example.com/admin?fragment"}, - {"https://test.example.com/admin", "https://test.example.com", "f", "fragment", "https://test.example.com/f?fragment"}, - {"", "https://test.example.com", "", "fragment", "https://test.example.com?fragment"}, - } - - for _, c := range cases { - act, err := getSiteURL(c.ReferrerURL, c.SiteURL, c.Path, c.Fragment) - assert.NoError(t, err, c.Expected) - assert.Equal(t, c.Expected, act) - } -} - -func TestRelativeURL(t *testing.T) { - cases := []struct { - URL string - Expected string - }{ - {"https://test.example.com", ""}, - {"http://test.example.com", ""}, - {"test.example.com", "test.example.com"}, - {"/some/path#fragment", "/some/path#fragment"}, - } - - for _, c := range cases { - res := enforceRelativeURL(c.URL) - assert.Equal(t, c.Expected, res, c.URL) - } -} diff --git a/mailer/noop.go b/mailer/noop.go deleted file mode 100644 index 9c486fc4e..000000000 --- a/mailer/noop.go +++ /dev/null @@ -1,38 +0,0 @@ -package mailer - -import "github.com/netlify/gotrue/models" - -type noopMailer struct { -} - -func (m noopMailer) ValidateEmail(email string) error { - return nil -} - -func (m *noopMailer) InviteMail(user *models.User, referrerURL string) error { - return nil -} - -func (m *noopMailer) ConfirmationMail(user *models.User, referrerURL string) error { - return nil -} - -func (m noopMailer) RecoveryMail(user *models.User, referrerURL string) error { - return nil -} - -func (m noopMailer) MagicLinkMail(user *models.User, referrerURL string) error { - return nil -} - -func (m *noopMailer) EmailChangeMail(user *models.User, referrerURL string) error { - return nil -} - -func (m noopMailer) Send(user *models.User, subject, body string, data map[string]interface{}) error { - return nil -} - -func (m noopMailer) GetEmailActionLink(user *models.User, actionType, referrerURL string) (string, error) { - return "", nil -} diff --git a/mailer/template.go b/mailer/template.go deleted file mode 100644 index d2764fff7..000000000 --- a/mailer/template.go +++ /dev/null @@ -1,286 +0,0 @@ -package mailer - -import ( - "fmt" - - "github.com/badoux/checkmail" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/models" - "github.com/netlify/mailme" -) - -// TemplateMailer will send mail and use templates from the site for easy mail styling -type TemplateMailer struct { - SiteURL string - Config *conf.Configuration - Mailer *mailme.Mailer -} - -var configFile = "" - -const defaultInviteMail = `

You have been invited

- -

You have been invited to create a user on {{ .SiteURL }}. Follow this link to accept the invite:

-

Accept the invite

` - -const defaultConfirmationMail = `

Confirm your email

- -

Follow this link to confirm your email:

-

Confirm your email address

` - -const defaultRecoveryMail = `

Reset password

- -

Follow this link to reset the password for your user:

-

Reset password

` - -const defaultMagicLinkMail = `

Magic Link

- -

Follow this link to login:

-

Log In

` - -const defaultEmailChangeMail = `

Confirm email address change

- -

Follow this link to confirm the update of your email address from {{ .Email }} to {{ .NewEmail }}:

-

Change email address

` - -// ValidateEmail returns nil if the email is valid, -// otherwise an error indicating the reason it is invalid -func (m TemplateMailer) ValidateEmail(email string) error { - return checkmail.ValidateFormat(email) -} - -// InviteMail sends a invite mail to a new user -func (m *TemplateMailer) InviteMail(user *models.User, referrerURL string) error { - globalConfig, err := conf.LoadGlobal(configFile) - - redirectParam := "" - if len(referrerURL) > 0 { - redirectParam = "&redirect_to=" + referrerURL - } - - url, err := getSiteURL(referrerURL, globalConfig.API.ExternalURL, m.Config.Mailer.URLPaths.Invite, "token="+user.ConfirmationToken+"&type=invite"+redirectParam) - if err != nil { - return err - } - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, - "Email": user.Email, - "Token": user.ConfirmationToken, - "Data": user.UserMetaData, - } - - return m.Mailer.Mail( - user.GetEmail(), - string(withDefault(m.Config.Mailer.Subjects.Invite, "You have been invited")), - m.Config.Mailer.Templates.Invite, - defaultInviteMail, - data, - ) -} - -// ConfirmationMail sends a signup confirmation mail to a new user -func (m *TemplateMailer) ConfirmationMail(user *models.User, referrerURL string) error { - globalConfig, err := conf.LoadGlobal(configFile) - - redirectParam := "" - if len(referrerURL) > 0 { - redirectParam = "&redirect_to=" + referrerURL - } - - url, err := getSiteURL(referrerURL, globalConfig.API.ExternalURL, m.Config.Mailer.URLPaths.Confirmation, "token="+user.ConfirmationToken+"&type=signup"+redirectParam) - if err != nil { - return err - } - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, - "Email": user.Email, - "Token": user.ConfirmationToken, - "Data": user.UserMetaData, - } - - return m.Mailer.Mail( - user.GetEmail(), - string(withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Your Email")), - m.Config.Mailer.Templates.Confirmation, - defaultConfirmationMail, - data, - ) -} - -// EmailChangeMail sends an email change confirmation mail to a user -func (m *TemplateMailer) EmailChangeMail(user *models.User, referrerURL string) error { - type Email struct { - Address string - Token string - Subject string - Template string - } - emails := []Email{ - { - Address: user.EmailChange, - Token: user.EmailChangeTokenNew, - Subject: string(withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change")), - Template: m.Config.Mailer.Templates.Confirmation, - }, - } - - if m.Config.Mailer.SecureEmailChangeEnabled { - emails = append(emails, Email{ - Address: user.GetEmail(), - Token: user.EmailChangeTokenCurrent, - Subject: string(withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Email Address")), - Template: m.Config.Mailer.Templates.EmailChange, - }) - } - - globalConfig, err := conf.LoadGlobal(configFile) - if err != nil { - return err - } - - redirectParam := "" - if len(referrerURL) > 0 { - redirectParam = "&redirect_to=" + referrerURL - } - errors := make(chan error) - for _, email := range emails { - url, err := getSiteURL( - referrerURL, - globalConfig.API.ExternalURL, - m.Config.Mailer.URLPaths.EmailChange, - "token="+email.Token+"&type=email_change"+redirectParam, - ) - if err != nil { - return err - } - go func(address, token, template string) { - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, - "Email": user.GetEmail(), - "NewEmail": user.EmailChange, - "Token": token, - "Data": user.UserMetaData, - } - errors <- m.Mailer.Mail( - address, - string(withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change")), - template, - defaultEmailChangeMail, - data, - ) - }(email.Address, email.Token, email.Template) - } - - for i := 0; i < len(emails); i++ { - e := <-errors - if e != nil { - return e - } - } - - return nil -} - -// RecoveryMail sends a password recovery mail -func (m *TemplateMailer) RecoveryMail(user *models.User, referrerURL string) error { - globalConfig, err := conf.LoadGlobal(configFile) - - redirectParam := "" - if len(referrerURL) > 0 { - redirectParam = "&redirect_to=" + referrerURL - } - - url, err := getSiteURL(referrerURL, globalConfig.API.ExternalURL, m.Config.Mailer.URLPaths.Recovery, "token="+user.RecoveryToken+"&type=recovery"+redirectParam) - if err != nil { - return err - } - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, - "Email": user.Email, - "Token": user.RecoveryToken, - "Data": user.UserMetaData, - } - - return m.Mailer.Mail( - user.GetEmail(), - string(withDefault(m.Config.Mailer.Subjects.Recovery, "Reset Your Password")), - m.Config.Mailer.Templates.Recovery, - defaultRecoveryMail, - data, - ) -} - -// MagicLinkMail sends a login link mail -func (m *TemplateMailer) MagicLinkMail(user *models.User, referrerURL string) error { - globalConfig, err := conf.LoadGlobal(configFile) - - redirectParam := "" - if len(referrerURL) > 0 { - redirectParam = "&redirect_to=" + referrerURL - } - - url, err := getSiteURL(referrerURL, globalConfig.API.ExternalURL, m.Config.Mailer.URLPaths.Recovery, "token="+user.RecoveryToken+"&type=magiclink"+redirectParam) - if err != nil { - return err - } - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, - "Email": user.Email, - "Token": user.RecoveryToken, - "Data": user.UserMetaData, - } - - return m.Mailer.Mail( - user.GetEmail(), - string(withDefault(m.Config.Mailer.Subjects.MagicLink, "Your Magic Link")), - m.Config.Mailer.Templates.MagicLink, - defaultMagicLinkMail, - data, - ) -} - -// Send can be used to send one-off emails to users -func (m TemplateMailer) Send(user *models.User, subject, body string, data map[string]interface{}) error { - return m.Mailer.Mail( - user.GetEmail(), - subject, - "", - body, - data, - ) -} - -// GetEmailActionLink returns a magiclink, recovery or invite link based on the actionType passed. -func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referrerURL string) (string, error) { - globalConfig, err := conf.LoadGlobal(configFile) - - redirectParam := "" - if len(referrerURL) > 0 { - redirectParam = "&redirect_to=" + referrerURL - } - - var url string - switch actionType { - case "magiclink": - url, err = getSiteURL(referrerURL, globalConfig.API.ExternalURL, m.Config.Mailer.URLPaths.Recovery, "token="+user.RecoveryToken+"&type=magiclink"+redirectParam) - case "recovery": - url, err = getSiteURL(referrerURL, globalConfig.API.ExternalURL, m.Config.Mailer.URLPaths.Recovery, "token="+user.RecoveryToken+"&type=recovery"+redirectParam) - case "invite": - url, err = getSiteURL(referrerURL, globalConfig.API.ExternalURL, m.Config.Mailer.URLPaths.Invite, "token="+user.ConfirmationToken+"&type=invite"+redirectParam) - case "signup": - url, err = getSiteURL(referrerURL, globalConfig.API.ExternalURL, m.Config.Mailer.URLPaths.Confirmation, "token="+user.ConfirmationToken+"&type=signup"+redirectParam) - default: - return "", fmt.Errorf("Invalid email action link type: %s", actionType) - } - - if err != nil { - return "", err - } - - return url, nil -} diff --git a/main.go b/main.go index 519f38031..745519383 100644 --- a/main.go +++ b/main.go @@ -1,13 +1,68 @@ package main import ( - "log" + "context" + "embed" + "os/signal" + "sync" + "syscall" + "time" - "github.com/netlify/gotrue/cmd" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/cmd" + "github.com/supabase/auth/internal/observability" ) +//go:embed migrations/* +var embeddedMigrations embed.FS + +func init() { + logrus.SetFormatter(&logrus.JSONFormatter{}) +} + func main() { - if err := cmd.RootCommand().Execute(); err != nil { - log.Fatal(err) + cmd.EmbeddedMigrations = embeddedMigrations + + execCtx, execCancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) + defer execCancel() + + go func() { + <-execCtx.Done() + logrus.Info("received graceful shutdown signal") + }() + + // command is expected to obey the cancellation signal on execCtx and + // block while it is running + if err := cmd.RootCommand().ExecuteContext(execCtx); err != nil { + logrus.WithError(err).Fatal(err) + } + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Minute) + defer shutdownCancel() + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + // wait for profiler, metrics and trace exporters to shut down gracefully + observability.WaitForCleanup(shutdownCtx) + }() + + cleanupDone := make(chan struct{}) + go func() { + defer close(cleanupDone) + wg.Wait() + }() + + select { + case <-shutdownCtx.Done(): + // cleanup timed out + return + + case <-cleanupDone: + // cleanup finished before timing out + return } } diff --git a/migrations/00_init_auth_schema.up.sql b/migrations/00_init_auth_schema.up.sql new file mode 100644 index 000000000..a040095ae --- /dev/null +++ b/migrations/00_init_auth_schema.up.sql @@ -0,0 +1,88 @@ +-- auth.users definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.users ( + instance_id uuid NULL, + id uuid NOT NULL UNIQUE, + aud varchar(255) NULL, + "role" varchar(255) NULL, + email varchar(255) NULL UNIQUE, + encrypted_password varchar(255) NULL, + confirmed_at timestamptz NULL, + invited_at timestamptz NULL, + confirmation_token varchar(255) NULL, + confirmation_sent_at timestamptz NULL, + recovery_token varchar(255) NULL, + recovery_sent_at timestamptz NULL, + email_change_token varchar(255) NULL, + email_change varchar(255) NULL, + email_change_sent_at timestamptz NULL, + last_sign_in_at timestamptz NULL, + raw_app_meta_data jsonb NULL, + raw_user_meta_data jsonb NULL, + is_super_admin bool NULL, + created_at timestamptz NULL, + updated_at timestamptz NULL, + CONSTRAINT users_pkey PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS users_instance_id_email_idx ON {{ index .Options "Namespace" }}.users USING btree (instance_id, email); +CREATE INDEX IF NOT EXISTS users_instance_id_idx ON {{ index .Options "Namespace" }}.users USING btree (instance_id); +comment on table {{ index .Options "Namespace" }}.users is 'Auth: Stores user login data within a secure schema.'; + +-- auth.refresh_tokens definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.refresh_tokens ( + instance_id uuid NULL, + id bigserial NOT NULL, + "token" varchar(255) NULL, + user_id varchar(255) NULL, + revoked bool NULL, + created_at timestamptz NULL, + updated_at timestamptz NULL, + CONSTRAINT refresh_tokens_pkey PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS refresh_tokens_instance_id_idx ON {{ index .Options "Namespace" }}.refresh_tokens USING btree (instance_id); +CREATE INDEX IF NOT EXISTS refresh_tokens_instance_id_user_id_idx ON {{ index .Options "Namespace" }}.refresh_tokens USING btree (instance_id, user_id); +CREATE INDEX IF NOT EXISTS refresh_tokens_token_idx ON {{ index .Options "Namespace" }}.refresh_tokens USING btree (token); +comment on table {{ index .Options "Namespace" }}.refresh_tokens is 'Auth: Store of tokens used to refresh JWT tokens once they expire.'; + +-- auth.instances definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.instances ( + id uuid NOT NULL, + uuid uuid NULL, + raw_base_config text NULL, + created_at timestamptz NULL, + updated_at timestamptz NULL, + CONSTRAINT instances_pkey PRIMARY KEY (id) +); +comment on table {{ index .Options "Namespace" }}.instances is 'Auth: Manages users across multiple sites.'; + +-- auth.audit_log_entries definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.audit_log_entries ( + instance_id uuid NULL, + id uuid NOT NULL, + payload json NULL, + created_at timestamptz NULL, + CONSTRAINT audit_log_entries_pkey PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS audit_logs_instance_id_idx ON {{ index .Options "Namespace" }}.audit_log_entries USING btree (instance_id); +comment on table {{ index .Options "Namespace" }}.audit_log_entries is 'Auth: Audit trail for user actions.'; + +-- auth.schema_migrations definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.schema_migrations ( + "version" varchar(255) NOT NULL, + CONSTRAINT schema_migrations_pkey PRIMARY KEY ("version") +); +comment on table {{ index .Options "Namespace" }}.schema_migrations is 'Auth: Manages updates to the auth system.'; + +-- Gets the User ID from the request cookie +create or replace function {{ index .Options "Namespace" }}.uid() returns uuid as $$ + select nullif(current_setting('request.jwt.claim.sub', true), '')::uuid; +$$ language sql stable; + +-- Gets the User ID from the request cookie +create or replace function {{ index .Options "Namespace" }}.role() returns text as $$ + select nullif(current_setting('request.jwt.claim.role', true), '')::text; +$$ language sql stable; diff --git a/migrations/20210710035447_alter_users.up.sql b/migrations/20210710035447_alter_users.up.sql index 40bcb5e77..fc8de129d 100644 --- a/migrations/20210710035447_alter_users.up.sql +++ b/migrations/20210710035447_alter_users.up.sql @@ -1,6 +1,6 @@ -- alter user schema -ALTER TABLE auth.users +ALTER TABLE {{ index .Options "Namespace" }}.users ADD COLUMN IF NOT EXISTS phone VARCHAR(15) NULL UNIQUE DEFAULT NULL, ADD COLUMN IF NOT EXISTS phone_confirmed_at timestamptz NULL DEFAULT NULL, ADD COLUMN IF NOT EXISTS phone_change VARCHAR(15) NULL DEFAULT '', @@ -11,9 +11,9 @@ DO $$ BEGIN IF NOT EXISTS(SELECT * FROM information_schema.columns - WHERE table_schema = 'auth' and table_name='users' and column_name='email_confirmed_at') + WHERE table_schema = '{{ index .Options "Namespace" }}' and table_name='users' and column_name='email_confirmed_at') THEN - ALTER TABLE "auth"."users" RENAME COLUMN "confirmed_at" TO "email_confirmed_at"; + ALTER TABLE "{{ index .Options "Namespace" }}"."users" RENAME COLUMN "confirmed_at" TO "email_confirmed_at"; END IF; END $$; diff --git a/migrations/20210722035447_adds_confirmed_at.up.sql b/migrations/20210722035447_adds_confirmed_at.up.sql index ac5f331aa..aabd42e98 100644 --- a/migrations/20210722035447_adds_confirmed_at.up.sql +++ b/migrations/20210722035447_adds_confirmed_at.up.sql @@ -1,4 +1,4 @@ -- adds confirmed at -ALTER TABLE auth.users +ALTER TABLE {{ index .Options "Namespace" }}.users ADD COLUMN IF NOT EXISTS confirmed_at timestamptz GENERATED ALWAYS AS (LEAST (users.email_confirmed_at, users.phone_confirmed_at)) STORED; diff --git a/migrations/20210730183235_add_email_change_confirmed.up.sql b/migrations/20210730183235_add_email_change_confirmed.up.sql index 9887f4597..dc92c9cfa 100644 --- a/migrations/20210730183235_add_email_change_confirmed.up.sql +++ b/migrations/20210730183235_add_email_change_confirmed.up.sql @@ -1,6 +1,6 @@ -- adds email_change_confirmed -ALTER TABLE auth.users +ALTER TABLE {{ index .Options "Namespace" }}.users ADD COLUMN IF NOT EXISTS email_change_token_current varchar(255) null DEFAULT '', ADD COLUMN IF NOT EXISTS email_change_confirm_status smallint DEFAULT 0 CHECK (email_change_confirm_status >= 0 AND email_change_confirm_status <= 2); @@ -8,8 +8,8 @@ DO $$ BEGIN IF NOT EXISTS(SELECT * FROM information_schema.columns - WHERE table_schema = 'auth' and table_name='users' and column_name='email_change_token_new') + WHERE table_schema = '{{ index .Options "Namespace" }}' and table_name='users' and column_name='email_change_token_new') THEN - ALTER TABLE "auth"."users" RENAME COLUMN "email_change_token" TO "email_change_token_new"; + ALTER TABLE "{{ index .Options "Namespace" }}"."users" RENAME COLUMN "email_change_token" TO "email_change_token_new"; END IF; END $$; diff --git a/migrations/20210909172000_create_identities_table.up.sql b/migrations/20210909172000_create_identities_table.up.sql index 10b1f2b90..2f3a53570 100644 --- a/migrations/20210909172000_create_identities_table.up.sql +++ b/migrations/20210909172000_create_identities_table.up.sql @@ -1,6 +1,6 @@ -- adds identities table -CREATE TABLE IF NOT EXISTS auth.identities ( +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.identities ( id text NOT NULL, user_id uuid NOT NULL, identity_data JSONB NOT NULL, @@ -9,6 +9,6 @@ CREATE TABLE IF NOT EXISTS auth.identities ( created_at timestamptz NULL, updated_at timestamptz NULL, CONSTRAINT identities_pkey PRIMARY KEY (provider, id), - CONSTRAINT identities_user_id_fkey FOREIGN KEY (user_id) REFERENCES auth.users(id) ON DELETE CASCADE + CONSTRAINT identities_user_id_fkey FOREIGN KEY (user_id) REFERENCES {{ index .Options "Namespace" }}.users(id) ON DELETE CASCADE ); -COMMENT ON TABLE auth.identities is 'Auth: Stores identities associated to a user.'; +COMMENT ON TABLE {{ index .Options "Namespace" }}.identities is 'Auth: Stores identities associated to a user.'; diff --git a/migrations/20210927181326_add_refresh_token_parent.up.sql b/migrations/20210927181326_add_refresh_token_parent.up.sql index 2a9ff0e3f..a2b1c73fe 100644 --- a/migrations/20210927181326_add_refresh_token_parent.up.sql +++ b/migrations/20210927181326_add_refresh_token_parent.up.sql @@ -1,8 +1,24 @@ -- adds parent column -ALTER TABLE auth.refresh_tokens -ADD COLUMN IF NOT EXISTS parent varchar(255) NULL, -ADD CONSTRAINT refresh_tokens_token_unique UNIQUE ("token"), -ADD CONSTRAINT refresh_tokens_parent_fkey FOREIGN KEY (parent) REFERENCES auth.refresh_tokens("token"); +ALTER TABLE {{ index .Options "Namespace" }}.refresh_tokens +ADD COLUMN IF NOT EXISTS parent varchar(255) NULL; + +DO $$ +BEGIN + IF NOT EXISTS(SELECT * + FROM information_schema.constraint_column_usage + WHERE table_schema = '{{ index .Options "Namespace" }}' and table_name='refresh_tokens' and constraint_name='refresh_tokens_token_unique') + THEN + ALTER TABLE "{{ index .Options "Namespace" }}"."refresh_tokens" ADD CONSTRAINT refresh_tokens_token_unique UNIQUE ("token"); + END IF; + + IF NOT EXISTS(SELECT * + FROM information_schema.constraint_column_usage + WHERE table_schema = '{{ index .Options "Namespace" }}' and table_name='refresh_tokens' and constraint_name='refresh_tokens_parent_fkey') + THEN + ALTER TABLE "{{ index .Options "Namespace" }}"."refresh_tokens" ADD CONSTRAINT refresh_tokens_parent_fkey FOREIGN KEY (parent) REFERENCES {{ index .Options "Namespace" }}.refresh_tokens("token"); + END IF; + + CREATE INDEX IF NOT EXISTS refresh_tokens_parent_idx ON "{{ index .Options "Namespace" }}"."refresh_tokens" USING btree (parent); +END $$; -CREATE INDEX IF NOT EXISTS refresh_tokens_parent_idx ON refresh_tokens USING btree (parent); diff --git a/migrations/20211122151130_create_user_id_idx.up.sql b/migrations/20211122151130_create_user_id_idx.up.sql index cfa3809ca..d259aae6b 100644 --- a/migrations/20211122151130_create_user_id_idx.up.sql +++ b/migrations/20211122151130_create_user_id_idx.up.sql @@ -1,3 +1,3 @@ -- create index on identities.user_id -CREATE INDEX IF NOT EXISTS identities_user_id_idx ON identities using btree (user_id); +CREATE INDEX IF NOT EXISTS identities_user_id_idx ON "{{ index .Options "Namespace" }}".identities using btree (user_id); diff --git a/migrations/20211124214934_update_auth_functions.up.sql b/migrations/20211124214934_update_auth_functions.up.sql index a03d8cf8a..2fb784bbb 100644 --- a/migrations/20211124214934_update_auth_functions.up.sql +++ b/migrations/20211124214934_update_auth_functions.up.sql @@ -1,34 +1,34 @@ -- update auth functions -create or replace function auth.uid() +create or replace function {{ index .Options "Namespace" }}.uid() returns uuid language sql stable as $$ select - coalesce( - current_setting('request.jwt.claim.sub', true), - (current_setting('request.jwt.claims', true)::jsonb ->> 'sub') - )::uuid + coalesce( + current_setting('request.jwt.claim.sub', true), + (current_setting('request.jwt.claims', true)::jsonb ->> 'sub') + )::uuid $$; -create or replace function auth.role() +create or replace function {{ index .Options "Namespace" }}.role() returns text language sql stable as $$ select - coalesce( - current_setting('request.jwt.claim.role', true), - (current_setting('request.jwt.claims', true)::jsonb ->> 'role') - )::text + coalesce( + current_setting('request.jwt.claim.role', true), + (current_setting('request.jwt.claims', true)::jsonb ->> 'role') + )::text $$; -create or replace function auth.email() +create or replace function {{ index .Options "Namespace" }}.email() returns text language sql stable as $$ select - coalesce( - current_setting('request.jwt.claim.email', true), - (current_setting('request.jwt.claims', true)::jsonb ->> 'email') - )::text + coalesce( + current_setting('request.jwt.claim.email', true), + (current_setting('request.jwt.claims', true)::jsonb ->> 'email') + )::text $$; diff --git a/migrations/20211202183645_update_auth_uid.up.sql b/migrations/20211202183645_update_auth_uid.up.sql index ad9db0b0c..3ecadfd5c 100644 --- a/migrations/20211202183645_update_auth_uid.up.sql +++ b/migrations/20211202183645_update_auth_uid.up.sql @@ -1,6 +1,6 @@ -- update auth.uid() -create or replace function auth.uid() +create or replace function {{ index .Options "Namespace" }}.uid() returns uuid language sql stable as $$ diff --git a/migrations/20220114185221_update_user_idx.up.sql b/migrations/20220114185221_update_user_idx.up.sql index 7572e047d..02fe76a07 100644 --- a/migrations/20220114185221_update_user_idx.up.sql +++ b/migrations/20220114185221_update_user_idx.up.sql @@ -1,4 +1,4 @@ -- updates users_instance_id_email_idx definition DROP INDEX IF EXISTS users_instance_id_email_idx; -CREATE INDEX IF NOT EXISTS users_instance_id_email_idx on users using btree (instance_id, lower(email)); +CREATE INDEX IF NOT EXISTS users_instance_id_email_idx on "{{ index .Options "Namespace" }}".users using btree (instance_id, lower(email)); diff --git a/migrations/20220114185340_add_banned_until.up.sql b/migrations/20220114185340_add_banned_until.up.sql index c2af132b0..7530a7c1b 100644 --- a/migrations/20220114185340_add_banned_until.up.sql +++ b/migrations/20220114185340_add_banned_until.up.sql @@ -1,4 +1,4 @@ -- adds banned_until column -ALTER TABLE auth.users +ALTER TABLE {{ index .Options "Namespace" }}.users ADD COLUMN IF NOT EXISTS banned_until timestamptz NULL; diff --git a/migrations/20220224000811_update_auth_functions.up.sql b/migrations/20220224000811_update_auth_functions.up.sql new file mode 100644 index 000000000..4be423739 --- /dev/null +++ b/migrations/20220224000811_update_auth_functions.up.sql @@ -0,0 +1,34 @@ +-- update auth functions + +create or replace function {{ index .Options "Namespace" }}.uid() +returns uuid +language sql stable +as $$ + select + coalesce( + nullif(current_setting('request.jwt.claim.sub', true), ''), + (nullif(current_setting('request.jwt.claims', true), '')::jsonb ->> 'sub') + )::uuid +$$; + +create or replace function {{ index .Options "Namespace" }}.role() +returns text +language sql stable +as $$ + select + coalesce( + nullif(current_setting('request.jwt.claim.role', true), ''), + (nullif(current_setting('request.jwt.claims', true), '')::jsonb ->> 'role') + )::text +$$; + +create or replace function {{ index .Options "Namespace" }}.email() +returns text +language sql stable +as $$ + select + coalesce( + nullif(current_setting('request.jwt.claim.email', true), ''), + (nullif(current_setting('request.jwt.claims', true), '')::jsonb ->> 'email') + )::text +$$; diff --git a/migrations/20220323170000_add_user_reauthentication.up.sql b/migrations/20220323170000_add_user_reauthentication.up.sql new file mode 100644 index 000000000..277dbdb5a --- /dev/null +++ b/migrations/20220323170000_add_user_reauthentication.up.sql @@ -0,0 +1,5 @@ +-- adds reauthentication_token and reauthentication_sent_at + +ALTER TABLE {{ index .Options "Namespace" }}.users +ADD COLUMN IF NOT EXISTS reauthentication_token varchar(255) null default '', +ADD COLUMN IF NOT EXISTS reauthentication_sent_at timestamptz null default null; diff --git a/migrations/20220429102000_add_unique_idx.up.sql b/migrations/20220429102000_add_unique_idx.up.sql new file mode 100644 index 000000000..9d7644df2 --- /dev/null +++ b/migrations/20220429102000_add_unique_idx.up.sql @@ -0,0 +1,14 @@ +-- add partial unique indices to confirmation_token, recovery_token, email_change_token_current, email_change_token_new, phone_change_token, reauthentication_token +-- ignores partial unique index creation on fields which contain empty strings, whitespaces or purely numeric otps + +DROP INDEX IF EXISTS confirmation_token_idx; +DROP INDEX IF EXISTS recovery_token_idx; +DROP INDEX IF EXISTS email_change_token_current_idx; +DROP INDEX IF EXISTS email_change_token_new_idx; +DROP INDEX IF EXISTS reauthentication_token_idx; + +CREATE UNIQUE INDEX IF NOT EXISTS confirmation_token_idx ON {{ index .Options "Namespace" }}.users USING btree (confirmation_token) WHERE confirmation_token !~ '^[0-9 ]*$'; +CREATE UNIQUE INDEX IF NOT EXISTS recovery_token_idx ON {{ index .Options "Namespace" }}.users USING btree (recovery_token) WHERE recovery_token !~ '^[0-9 ]*$'; +CREATE UNIQUE INDEX IF NOT EXISTS email_change_token_current_idx ON {{ index .Options "Namespace" }}.users USING btree (email_change_token_current) WHERE email_change_token_current !~ '^[0-9 ]*$'; +CREATE UNIQUE INDEX IF NOT EXISTS email_change_token_new_idx ON {{ index .Options "Namespace" }}.users USING btree (email_change_token_new) WHERE email_change_token_new !~ '^[0-9 ]*$'; +CREATE UNIQUE INDEX IF NOT EXISTS reauthentication_token_idx ON {{ index .Options "Namespace" }}.users USING btree (reauthentication_token) WHERE reauthentication_token !~ '^[0-9 ]*$'; diff --git a/migrations/20220531120530_add_auth_jwt_function.up.sql b/migrations/20220531120530_add_auth_jwt_function.up.sql new file mode 100644 index 000000000..11f84e85f --- /dev/null +++ b/migrations/20220531120530_add_auth_jwt_function.up.sql @@ -0,0 +1,16 @@ +-- add auth.jwt function + +comment on function {{ index .Options "Namespace" }}.uid() is 'Deprecated. Use auth.jwt() -> ''sub'' instead.'; +comment on function {{ index .Options "Namespace" }}.role() is 'Deprecated. Use auth.jwt() -> ''role'' instead.'; +comment on function {{ index .Options "Namespace" }}.email() is 'Deprecated. Use auth.jwt() -> ''email'' instead.'; + +create or replace function {{ index .Options "Namespace" }}.jwt() +returns jsonb +language sql stable +as $$ + select + coalesce( + nullif(current_setting('request.jwt.claim', true), ''), + nullif(current_setting('request.jwt.claims', true), '') + )::jsonb +$$; diff --git a/migrations/20220614074223_add_ip_address_to_audit_log.postgres.up.sql b/migrations/20220614074223_add_ip_address_to_audit_log.postgres.up.sql new file mode 100644 index 000000000..a1a66b4b3 --- /dev/null +++ b/migrations/20220614074223_add_ip_address_to_audit_log.postgres.up.sql @@ -0,0 +1,3 @@ +-- Add IP Address to audit log +ALTER TABLE {{ index .Options "Namespace" }}.audit_log_entries +ADD COLUMN IF NOT EXISTS ip_address VARCHAR(64) NOT NULL DEFAULT ''; diff --git a/migrations/20220811173540_add_sessions_table.up.sql b/migrations/20220811173540_add_sessions_table.up.sql new file mode 100644 index 000000000..c16ef3ca5 --- /dev/null +++ b/migrations/20220811173540_add_sessions_table.up.sql @@ -0,0 +1,23 @@ +-- Add session_id column to refresh_tokens table +create table if not exists {{ index .Options "Namespace" }}.sessions ( + id uuid not null, + user_id uuid not null, + created_at timestamptz null, + updated_at timestamptz null, + constraint sessions_pkey primary key (id), + constraint sessions_user_id_fkey foreign key (user_id) references {{ index .Options "Namespace" }}.users(id) on delete cascade +); +comment on table {{ index .Options "Namespace" }}.sessions is 'Auth: Stores session data associated to a user.'; + +alter table {{ index .Options "Namespace" }}.refresh_tokens +add column if not exists session_id uuid null; + +do $$ +begin + if not exists(select * + from information_schema.constraint_column_usage + where table_schema = '{{ index .Options "Namespace" }}' and table_name='sessions' and constraint_name='refresh_tokens_session_id_fkey') + then + alter table "{{ index .Options "Namespace" }}"."refresh_tokens" add constraint refresh_tokens_session_id_fkey foreign key (session_id) references {{ index .Options "Namespace" }}.sessions(id) on delete cascade; + end if; +END $$; diff --git a/migrations/20221003041349_add_mfa_schema.up.sql b/migrations/20221003041349_add_mfa_schema.up.sql new file mode 100644 index 000000000..a44654aed --- /dev/null +++ b/migrations/20221003041349_add_mfa_schema.up.sql @@ -0,0 +1,50 @@ +-- see: https://stackoverflow.com/questions/7624919/check-if-a-user-defined-type-already-exists-in-postgresql/48382296#48382296 +do $$ begin + create type factor_type as enum('totp', 'webauthn'); + create type factor_status as enum('unverified', 'verified'); + create type aal_level as enum('aal1', 'aal2', 'aal3'); +exception + when duplicate_object then null; +end $$; + +-- auth.mfa_factors definition +create table if not exists {{ index .Options "Namespace" }}.mfa_factors( + id uuid not null, + user_id uuid not null, + friendly_name text null, + factor_type factor_type not null, + status factor_status not null, + created_at timestamptz not null, + updated_at timestamptz not null, + secret text null, + constraint mfa_factors_pkey primary key(id), + constraint mfa_factors_user_id_fkey foreign key (user_id) references {{ index .Options "Namespace" }}.users(id) on delete cascade +); +comment on table {{ index .Options "Namespace" }}.mfa_factors is 'auth: stores metadata about factors'; + +create unique index if not exists mfa_factors_user_friendly_name_unique on {{ index .Options "Namespace" }}.mfa_factors (friendly_name, user_id) where trim(friendly_name) <> ''; + +-- auth.mfa_challenges definition +create table if not exists {{ index .Options "Namespace" }}.mfa_challenges( + id uuid not null, + factor_id uuid not null, + created_at timestamptz not null, + verified_at timestamptz null, + ip_address inet not null, + constraint mfa_challenges_pkey primary key (id), + constraint mfa_challenges_auth_factor_id_fkey foreign key (factor_id) references {{ index .Options "Namespace" }}.mfa_factors(id) on delete cascade +); +comment on table {{ index .Options "Namespace" }}.mfa_challenges is 'auth: stores metadata about challenge requests made'; + + + +-- add factor_id and amr claims to session +create table if not exists {{ index .Options "Namespace" }}.mfa_amr_claims( + session_id uuid not null, + created_at timestamptz not null, + updated_at timestamptz not null, + authentication_method text not null, + constraint mfa_amr_claims_session_id_authentication_method_pkey unique(session_id, authentication_method), + constraint mfa_amr_claims_session_id_fkey foreign key(session_id) references {{ index .Options "Namespace" }}.sessions(id) on delete cascade +); +comment on table {{ index .Options "Namespace" }}.mfa_amr_claims is 'auth: stores authenticator method reference claims for multi factor authentication'; diff --git a/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql b/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql new file mode 100644 index 000000000..cc8a2096d --- /dev/null +++ b/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql @@ -0,0 +1,3 @@ +-- add factor_id to sessions + alter table {{ index .Options "Namespace" }}.sessions add column if not exists factor_id uuid null; + alter table {{ index .Options "Namespace" }}.sessions add column if not exists aal aal_level null; diff --git a/migrations/20221011041400_add_mfa_indexes.up.sql b/migrations/20221011041400_add_mfa_indexes.up.sql new file mode 100644 index 000000000..def57a299 --- /dev/null +++ b/migrations/20221011041400_add_mfa_indexes.up.sql @@ -0,0 +1,19 @@ +alter table {{ index .Options "Namespace" }}.mfa_amr_claims + add column if not exists id uuid not null; + +do $$ +begin + if not exists + (select constraint_name + from information_schema.table_constraints + where table_schema = '{{ index .Options "Namespace" }}' + and table_name = 'mfa_amr_claims' + and constraint_name = 'amr_id_pk') + then + alter table {{ index .Options "Namespace" }}.mfa_amr_claims add constraint amr_id_pk primary key(id); + end if; +end $$; + +create index if not exists user_id_created_at_idx on {{ index .Options "Namespace" }}.sessions (user_id, created_at); +create index if not exists factor_id_created_at_idx on {{ index .Options "Namespace" }}.mfa_factors (user_id, created_at); + diff --git a/migrations/20221020193600_add_sessions_user_id_index.up.sql b/migrations/20221020193600_add_sessions_user_id_index.up.sql new file mode 100644 index 000000000..f5ba04257 --- /dev/null +++ b/migrations/20221020193600_add_sessions_user_id_index.up.sql @@ -0,0 +1,2 @@ +create index if not exists sessions_user_id_idx on {{ index .Options "Namespace" }}.sessions (user_id); + diff --git a/migrations/20221021073300_add_refresh_tokens_session_id_revoked_index.up.sql b/migrations/20221021073300_add_refresh_tokens_session_id_revoked_index.up.sql new file mode 100644 index 000000000..0c47d4a75 --- /dev/null +++ b/migrations/20221021073300_add_refresh_tokens_session_id_revoked_index.up.sql @@ -0,0 +1 @@ +create index if not exists refresh_tokens_session_id_revoked_idx on {{ index .Options "Namespace" }}.refresh_tokens (session_id, revoked); diff --git a/migrations/20221021082433_add_saml.up.sql b/migrations/20221021082433_add_saml.up.sql new file mode 100644 index 000000000..30ac3d03a --- /dev/null +++ b/migrations/20221021082433_add_saml.up.sql @@ -0,0 +1,90 @@ +-- Multi-instance mode (see auth.instances) table intentionally not supported and ignored. + +create table if not exists {{ index .Options "Namespace" }}.sso_providers ( + id uuid not null, + resource_id text null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + constraint "resource_id not empty" check (resource_id = null or char_length(resource_id) > 0) +); + +comment on table {{ index .Options "Namespace" }}.sso_providers is 'Auth: Manages SSO identity provider information; see saml_providers for SAML.'; +comment on column {{ index .Options "Namespace" }}.sso_providers.resource_id is 'Auth: Uniquely identifies a SSO provider according to a user-chosen resource ID (case insensitive), useful in infrastructure as code.'; + +create unique index if not exists sso_providers_resource_id_idx on {{ index .Options "Namespace" }}.sso_providers (lower(resource_id)); + +create table if not exists {{ index .Options "Namespace" }}.sso_domains ( + id uuid not null, + sso_provider_id uuid not null, + domain text not null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade, + constraint "domain not empty" check (char_length(domain) > 0) +); + +create index if not exists sso_domains_sso_provider_id_idx on {{ index .Options "Namespace" }}.sso_domains (sso_provider_id); +create unique index if not exists sso_domains_domain_idx on {{ index .Options "Namespace" }}.sso_domains (lower(domain)); + +comment on table {{ index .Options "Namespace" }}.sso_domains is 'Auth: Manages SSO email address domain mapping to an SSO Identity Provider.'; + +create table if not exists {{ index .Options "Namespace" }}.saml_providers ( + id uuid not null, + sso_provider_id uuid not null, + entity_id text not null unique, + metadata_xml text not null, + metadata_url text null, + attribute_mapping jsonb null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade, + constraint "metadata_xml not empty" check (char_length(metadata_xml) > 0), + constraint "metadata_url not empty" check (metadata_url = null or char_length(metadata_url) > 0), + constraint "entity_id not empty" check (char_length(entity_id) > 0) +); + +create index if not exists saml_providers_sso_provider_id_idx on {{ index .Options "Namespace" }}.saml_providers (sso_provider_id); + +comment on table {{ index .Options "Namespace" }}.saml_providers is 'Auth: Manages SAML Identity Provider connections.'; + +create table if not exists {{ index .Options "Namespace" }}.saml_relay_states ( + id uuid not null, + sso_provider_id uuid not null, + request_id text not null, + for_email text null, + redirect_to text null, + from_ip_address inet null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade, + constraint "request_id not empty" check(char_length(request_id) > 0) +); + +create index if not exists saml_relay_states_sso_provider_id_idx on {{ index .Options "Namespace" }}.saml_relay_states (sso_provider_id); +create index if not exists saml_relay_states_for_email_idx on {{ index .Options "Namespace" }}.saml_relay_states (for_email); + +comment on table {{ index .Options "Namespace" }}.saml_relay_states is 'Auth: Contains SAML Relay State information for each Service Provider initiated login.'; + +create table if not exists {{ index .Options "Namespace" }}.sso_sessions ( + id uuid not null, + session_id uuid not null, + sso_provider_id uuid null, + not_before timestamptz null, + not_after timestamptz null, + idp_initiated boolean default false, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (session_id) references {{ index .Options "Namespace" }}.sessions (id) on delete cascade, + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade +); + +create index if not exists sso_sessions_session_id_idx on {{ index .Options "Namespace" }}.sso_sessions (session_id); +create index if not exists sso_sessions_sso_provider_id_idx on {{ index .Options "Namespace" }}.sso_sessions (sso_provider_id); + +comment on table {{ index .Options "Namespace" }}.sso_sessions is 'Auth: A session initiated by an SSO Identity Provider'; + diff --git a/migrations/20221027105023_add_identities_user_id_idx.up.sql b/migrations/20221027105023_add_identities_user_id_idx.up.sql new file mode 100644 index 000000000..12e7aa556 --- /dev/null +++ b/migrations/20221027105023_add_identities_user_id_idx.up.sql @@ -0,0 +1 @@ +create index if not exists identities_user_id_idx on {{ index .Options "Namespace" }}.identities using btree (user_id); diff --git a/migrations/20221114143122_add_session_not_after_column.up.sql b/migrations/20221114143122_add_session_not_after_column.up.sql new file mode 100644 index 000000000..c729911c1 --- /dev/null +++ b/migrations/20221114143122_add_session_not_after_column.up.sql @@ -0,0 +1,4 @@ +alter table only {{ index .Options "Namespace" }}.sessions + add column if not exists not_after timestamptz; + +comment on column {{ index .Options "Namespace" }}.sessions.not_after is 'Auth: Not after is a nullable column that contains a timestamp after which the session should be regarded as expired.'; diff --git a/migrations/20221114143410_remove_parent_foreign_key_refresh_tokens.up.sql b/migrations/20221114143410_remove_parent_foreign_key_refresh_tokens.up.sql new file mode 100644 index 000000000..62d207848 --- /dev/null +++ b/migrations/20221114143410_remove_parent_foreign_key_refresh_tokens.up.sql @@ -0,0 +1,2 @@ +alter table only {{ index .Options "Namespace" }}.refresh_tokens + drop constraint refresh_tokens_parent_fkey; diff --git a/migrations/20221125140132_backfill_email_identity.up.sql b/migrations/20221125140132_backfill_email_identity.up.sql new file mode 100644 index 000000000..cd06425c0 --- /dev/null +++ b/migrations/20221125140132_backfill_email_identity.up.sql @@ -0,0 +1,11 @@ +-- backfill the auth.identities column by adding an email identity +-- for all auth.users with an email and password + +do $$ +begin + insert into {{ index .Options "Namespace" }}.identities (id, user_id, identity_data, provider, last_sign_in_at, created_at, updated_at) + select id, id as user_id, jsonb_build_object('sub', id, 'email', email) as identity_data, 'email' as provider, null as last_sign_in_at, '2022-11-25' as created_at, '2022-11-25' as updated_at + from {{ index .Options "Namespace" }}.users as users + where encrypted_password != '' and email is not null and not exists(select user_id from {{ index .Options "Namespace" }}.identities where user_id = users.id); +end; +$$; diff --git a/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql b/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql new file mode 100644 index 000000000..19ec79e9e --- /dev/null +++ b/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql @@ -0,0 +1,13 @@ +-- previous backfill migration left last_sign_in_at to be null, which broke some projects + +do $$ +begin +update {{ index .Options "Namespace" }}.identities + set last_sign_in_at = '2022-11-25' + where + last_sign_in_at is null and + created_at = '2022-11-25' and + updated_at = '2022-11-25' and + provider = 'email' and + id = user_id::text; +end $$; diff --git a/migrations/20221215195500_modify_users_email_unique_index.up.sql b/migrations/20221215195500_modify_users_email_unique_index.up.sql new file mode 100644 index 000000000..c12de04f6 --- /dev/null +++ b/migrations/20221215195500_modify_users_email_unique_index.up.sql @@ -0,0 +1,23 @@ +-- this change is relatively temporary +-- it is meant to keep database consistency guarantees until there is proper +-- introduction of account linking / merging / delinking APIs, at which point +-- rows in the users table will allow duplicates but with programmatic control + +alter table only {{ index .Options "Namespace" }}.users + add column if not exists is_sso_user boolean not null default false; + +comment on column {{ index .Options "Namespace" }}.users.is_sso_user is 'Auth: Set this column to true when the account comes from SSO. These accounts can have duplicate emails.'; + +do $$ +begin + alter table only {{ index .Options "Namespace" }}.users + drop constraint if exists users_email_key; +exception +-- dependent object: https://www.postgresql.org/docs/current/errcodes-appendix.html +when SQLSTATE '2BP01' then + raise notice 'Unable to drop users_email_key constraint due to dependent objects, please resolve this manually or SSO may not work'; +end $$; + +create unique index if not exists users_email_partial_key on {{ index .Options "Namespace" }}.users (email) where (is_sso_user = false); + +comment on index {{ index .Options "Namespace" }}.users_email_partial_key is 'Auth: A partial unique index that applies only when is_sso_user is false'; diff --git a/migrations/20221215195800_add_identities_email_column.up.sql b/migrations/20221215195800_add_identities_email_column.up.sql new file mode 100644 index 000000000..eb60334c4 --- /dev/null +++ b/migrations/20221215195800_add_identities_email_column.up.sql @@ -0,0 +1,18 @@ +do $$ +begin + update + {{ index .Options "Namespace" }}.identities as identities + set + identity_data = identity_data || jsonb_build_object('email', (select email from {{ index .Options "Namespace" }}.users where id = identities.user_id)), + updated_at = '2022-11-25' + where identities.provider = 'email' and identity_data->>'email' is null; +end $$; + +alter table only {{ index .Options "Namespace" }}.identities + add column if not exists email text generated always as (lower(identity_data->>'email')) stored; + +comment on column {{ index .Options "Namespace" }}.identities.email is 'Auth: Email is a generated column that references the optional email property in the identity_data'; + +create index if not exists identities_email_idx on {{ index .Options "Namespace" }}.identities (email text_pattern_ops); + +comment on index {{ index .Options "Namespace" }}.identities_email_idx is 'Auth: Ensures indexed queries on the email column'; diff --git a/migrations/20221215195900_remove_sso_sessions.up.sql b/migrations/20221215195900_remove_sso_sessions.up.sql new file mode 100644 index 000000000..228302d01 --- /dev/null +++ b/migrations/20221215195900_remove_sso_sessions.up.sql @@ -0,0 +1,3 @@ +-- sso_sessions is not used as all of the necessary data is in sessions +drop table if exists {{ index .Options "Namespace" }}.sso_sessions; + diff --git a/migrations/20230116124310_alter_phone_type.up.sql b/migrations/20230116124310_alter_phone_type.up.sql new file mode 100644 index 000000000..fa846dbb0 --- /dev/null +++ b/migrations/20230116124310_alter_phone_type.up.sql @@ -0,0 +1,14 @@ +-- alter phone field column type to accomodate for soft deletion + +do $$ +begin + alter table {{ index .Options "Namespace" }}.users + alter column phone type text, + alter column phone_change type text; +exception + -- SQLSTATE errcodes https://www.postgresql.org/docs/current/errcodes-appendix.html + when SQLSTATE '0A000' then + raise notice 'Unable to change data type of phone, phone_change columns due to use by a view or rule'; + when SQLSTATE '2BP01' then + raise notice 'Unable to change data type of phone, phone_change columns due to dependent objects'; +end $$; diff --git a/migrations/20230116124412_add_deleted_at.up.sql b/migrations/20230116124412_add_deleted_at.up.sql new file mode 100644 index 000000000..999abaa88 --- /dev/null +++ b/migrations/20230116124412_add_deleted_at.up.sql @@ -0,0 +1,4 @@ +-- adds deleted_at column to auth.users + +alter table {{ index .Options "Namespace" }}.users +add column if not exists deleted_at timestamptz null; diff --git a/migrations/20230131181311_backfill_invite_identities.up.sql b/migrations/20230131181311_backfill_invite_identities.up.sql new file mode 100644 index 000000000..2fcb35828 --- /dev/null +++ b/migrations/20230131181311_backfill_invite_identities.up.sql @@ -0,0 +1,9 @@ +-- backfills the missing email identity for invited users + +do $$ +begin + insert into {{ index .Options "Namespace" }}.identities (id, user_id, identity_data, provider, last_sign_in_at, created_at, updated_at) + select id, id as user_id, jsonb_build_object('sub', id, 'email', email) as identity_data, 'email' as provider, null as last_sign_in_at, '2023-01-25' as created_at, '2023-01-25' as updated_at + from {{ index .Options "Namespace" }}.users as users + where invited_at is not null and not exists (select user_id from {{ index .Options "Namespace" }}.identities where user_id = users.id and provider = 'email'); +end $$; diff --git a/migrations/20230322519590_add_flow_state_table.up.sql b/migrations/20230322519590_add_flow_state_table.up.sql new file mode 100644 index 000000000..a8842e5b0 --- /dev/null +++ b/migrations/20230322519590_add_flow_state_table.up.sql @@ -0,0 +1,20 @@ +-- see: https://stackoverflow.com/questions/7624919/check-if-a-user-defined-type-already-exists-in-postgresql/48382296#48382296 +do $$ begin + create type code_challenge_method as enum('s256', 'plain'); +exception + when duplicate_object then null; +end $$; +create table if not exists {{ index .Options "Namespace" }}.flow_state( + id uuid primary key, + user_id uuid null, + auth_code text not null, + code_challenge_method code_challenge_method not null, + code_challenge text not null, + provider_type text not null, + provider_access_token text null, + provider_refresh_token text null, + created_at timestamptz null, + updated_at timestamptz null +); +create index if not exists idx_auth_code on {{ index .Options "Namespace" }}.flow_state(auth_code); +comment on table {{ index .Options "Namespace" }}.flow_state is 'stores metadata for pkce logins'; diff --git a/migrations/20230402418590_add_authentication_method_to_flow_state_table.up.sql b/migrations/20230402418590_add_authentication_method_to_flow_state_table.up.sql new file mode 100644 index 000000000..e83af8566 --- /dev/null +++ b/migrations/20230402418590_add_authentication_method_to_flow_state_table.up.sql @@ -0,0 +1,6 @@ +alter table {{index .Options "Namespace" }}.flow_state +add column if not exists authentication_method text not null; +create index if not exists idx_user_id_auth_method on {{index .Options "Namespace" }}.flow_state (user_id, authentication_method); + +-- Update comment as we have generalized the table +comment on table {{ index .Options "Namespace" }}.flow_state is 'stores metadata for pkce logins'; diff --git a/migrations/20230411005111_remove_duplicate_idx.up.sql b/migrations/20230411005111_remove_duplicate_idx.up.sql new file mode 100644 index 000000000..dc23931c7 --- /dev/null +++ b/migrations/20230411005111_remove_duplicate_idx.up.sql @@ -0,0 +1 @@ +drop index if exists {{index .Options "Namespace" }}.refresh_tokens_token_idx; diff --git a/migrations/20230508135423_add_cleanup_indexes.up.sql b/migrations/20230508135423_add_cleanup_indexes.up.sql new file mode 100644 index 000000000..162acee15 --- /dev/null +++ b/migrations/20230508135423_add_cleanup_indexes.up.sql @@ -0,0 +1,17 @@ +-- Indexes used for cleaning up old or stale objects. + +create index if not exists + refresh_tokens_updated_at_idx + on {{ index .Options "Namespace" }}.refresh_tokens (updated_at desc); + +create index if not exists + flow_state_created_at_idx + on {{ index .Options "Namespace" }}.flow_state (created_at desc); + +create index if not exists + saml_relay_states_created_at_idx + on {{ index .Options "Namespace" }}.saml_relay_states (created_at desc); + +create index if not exists + sessions_not_after_idx + on {{ index .Options "Namespace" }}.sessions (not_after desc); diff --git a/migrations/20230523124323_add_mfa_challenge_cleanup_index.up.sql b/migrations/20230523124323_add_mfa_challenge_cleanup_index.up.sql new file mode 100644 index 000000000..667d5020b --- /dev/null +++ b/migrations/20230523124323_add_mfa_challenge_cleanup_index.up.sql @@ -0,0 +1,5 @@ +-- Index used to clean up mfa challenges + +create index if not exists + mfa_challenge_created_at_idx + on {{ index .Options "Namespace" }}.mfa_challenges (created_at desc); diff --git a/migrations/20230818113222_add_flow_state_to_relay_state.up.sql b/migrations/20230818113222_add_flow_state_to_relay_state.up.sql new file mode 100644 index 000000000..f940e706c --- /dev/null +++ b/migrations/20230818113222_add_flow_state_to_relay_state.up.sql @@ -0,0 +1 @@ +alter table {{ index .Options "Namespace" }}.saml_relay_states add column if not exists flow_state_id uuid references {{ index .Options "Namespace" }}.flow_state(id) on delete cascade default null; diff --git a/migrations/20230914180801_add_mfa_factors_user_id_idx.up.sql b/migrations/20230914180801_add_mfa_factors_user_id_idx.up.sql new file mode 100644 index 000000000..805c97cb8 --- /dev/null +++ b/migrations/20230914180801_add_mfa_factors_user_id_idx.up.sql @@ -0,0 +1 @@ +create index if not exists mfa_factors_user_id_idx on {{ index .Options "Namespace" }}.mfa_factors(user_id); diff --git a/migrations/20231027141322_add_session_refresh_columns.up.sql b/migrations/20231027141322_add_session_refresh_columns.up.sql new file mode 100644 index 000000000..79efba9bc --- /dev/null +++ b/migrations/20231027141322_add_session_refresh_columns.up.sql @@ -0,0 +1,4 @@ +alter table if exists {{ index .Options "Namespace" }}.sessions + add column if not exists refreshed_at timestamp without time zone, + add column if not exists user_agent text, + add column if not exists ip inet; diff --git a/migrations/20231114161723_add_sessions_tag.up.sql b/migrations/20231114161723_add_sessions_tag.up.sql new file mode 100644 index 000000000..7acf1bb9d --- /dev/null +++ b/migrations/20231114161723_add_sessions_tag.up.sql @@ -0,0 +1,2 @@ +alter table if exists {{ index .Options "Namespace" }}.sessions + add column if not exists tag text; diff --git a/migrations/20231117164230_add_id_pkey_identities.up.sql b/migrations/20231117164230_add_id_pkey_identities.up.sql new file mode 100644 index 000000000..31ed280d3 --- /dev/null +++ b/migrations/20231117164230_add_id_pkey_identities.up.sql @@ -0,0 +1,29 @@ +do $$ +begin + if not exists(select * + from information_schema.columns + where table_schema = '{{ index .Options "Namespace" }}' and table_name='identities' and column_name='provider_id') + then + alter table if exists {{ index .Options "Namespace" }}.identities + rename column id to provider_id; + end if; +end$$; + +alter table if exists {{ index .Options "Namespace" }}.identities + drop constraint if exists identities_pkey, + add column if not exists id uuid default gen_random_uuid() primary key; + +do $$ +begin + if not exists + (select constraint_name + from information_schema.table_constraints + where table_schema = '{{ index .Options "Namespace" }}' + and table_name = 'identities' + and constraint_name = 'identities_provider_id_provider_unique') + then + alter table if exists {{ index .Options "Namespace" }}.identities + add constraint identities_provider_id_provider_unique + unique(provider_id, provider); + end if; +end $$; diff --git a/migrations/20240115144230_remove_ip_address_from_saml_relay_state.up.sql b/migrations/20240115144230_remove_ip_address_from_saml_relay_state.up.sql new file mode 100644 index 000000000..169ec37b2 --- /dev/null +++ b/migrations/20240115144230_remove_ip_address_from_saml_relay_state.up.sql @@ -0,0 +1,7 @@ +do $$ +begin + if exists (select from information_schema.columns where table_schema = '{{ index .Options "Namespace" }}' and table_name = 'saml_relay_states' and column_name = 'from_ip_address') then + alter table {{ index .Options "Namespace" }}.saml_relay_states drop column from_ip_address; + end if; +end +$$; diff --git a/migrations/20240214120130_add_is_anonymous_column.up.sql b/migrations/20240214120130_add_is_anonymous_column.up.sql new file mode 100644 index 000000000..6ef963f57 --- /dev/null +++ b/migrations/20240214120130_add_is_anonymous_column.up.sql @@ -0,0 +1,8 @@ +do $$ +begin + alter table {{ index .Options "Namespace" }}.users + add column if not exists is_anonymous boolean not null default false; + + create index if not exists users_is_anonymous_idx on {{ index .Options "Namespace" }}.users using btree (is_anonymous); +end +$$; diff --git a/migrations/20240306115329_add_issued_at_to_flow_state.up.sql b/migrations/20240306115329_add_issued_at_to_flow_state.up.sql new file mode 100644 index 000000000..d6eff157a --- /dev/null +++ b/migrations/20240306115329_add_issued_at_to_flow_state.up.sql @@ -0,0 +1,3 @@ +do $$ begin +alter table {{ index .Options "Namespace" }}.flow_state add column if not exists auth_code_issued_at timestamptz null; +end $$ diff --git a/migrations/20240314092811_add_saml_name_id_format.up.sql b/migrations/20240314092811_add_saml_name_id_format.up.sql new file mode 100644 index 000000000..0196250d2 --- /dev/null +++ b/migrations/20240314092811_add_saml_name_id_format.up.sql @@ -0,0 +1,3 @@ +do $$ begin +alter table {{ index .Options "Namespace" }}.saml_providers add column if not exists name_id_format text null; +end $$ diff --git a/migrations/20240409123726_add_phone_confirmation_sent_at.up.sql b/migrations/20240409123726_add_phone_confirmation_sent_at.up.sql new file mode 100644 index 000000000..ec7c611f7 --- /dev/null +++ b/migrations/20240409123726_add_phone_confirmation_sent_at.up.sql @@ -0,0 +1,33 @@ +-- Add phone_confirmation_sent_at column to users table +ALTER TABLE "auth"."users" ADD COLUMN IF NOT EXISTS "phone_confirmation_sent_at" timestamptz; + +-- Update the trigger that checks empty timestamps +CREATE OR REPLACE FUNCTION "auth"."set_empty_timestamps_as_null"() RETURNS TRIGGER AS $$ +BEGIN + IF NEW."created_at" IS NOT NULL AND NEW."created_at"::timestamptz = 'epoch'::timestamptz THEN + NEW."created_at" = NULL; + END IF; + IF NEW."updated_at" IS NOT NULL AND NEW."updated_at"::timestamptz = 'epoch'::timestamptz THEN + NEW."updated_at" = NULL; + END IF; + IF NEW."confirmed_at" IS NOT NULL AND NEW."confirmed_at"::timestamptz = 'epoch'::timestamptz THEN + NEW."confirmed_at" = NULL; + END IF; + IF NEW."confirmation_sent_at" IS NOT NULL AND NEW."confirmation_sent_at"::timestamptz = 'epoch'::timestamptz THEN + NEW."confirmation_sent_at" = NULL; + END IF; + IF NEW."phone_confirmation_sent_at" IS NOT NULL AND NEW."phone_confirmation_sent_at"::timestamptz = 'epoch'::timestamptz THEN + NEW."phone_confirmation_sent_at" = NULL; + END IF; + IF NEW."recovery_sent_at" IS NOT NULL AND NEW."recovery_sent_at"::timestamptz = 'epoch'::timestamptz THEN + NEW."recovery_sent_at" = NULL; + END IF; + IF NEW."email_change_sent_at" IS NOT NULL AND NEW."email_change_sent_at"::timestamptz = 'epoch'::timestamptz THEN + NEW."email_change_sent_at" = NULL; + END IF; + IF NEW."phone_change_sent_at" IS NOT NULL AND NEW."phone_change_sent_at"::timestamptz = 'epoch'::timestamptz THEN + NEW."phone_change_sent_at" = NULL; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; \ No newline at end of file diff --git a/migrations/20240427152123_add_one_time_tokens_table.up.sql b/migrations/20240427152123_add_one_time_tokens_table.up.sql new file mode 100644 index 000000000..be7312656 --- /dev/null +++ b/migrations/20240427152123_add_one_time_tokens_table.up.sql @@ -0,0 +1,37 @@ +do $$ begin + create type one_time_token_type as enum ( + 'confirmation_token', + 'reauthentication_token', + 'recovery_token', + 'email_change_token_new', + 'email_change_token_current', + 'phone_change_token' + ); +exception + when duplicate_object then null; +end $$; + + +do $$ begin + create table if not exists {{ index .Options "Namespace" }}.one_time_tokens ( + id uuid primary key, + user_id uuid not null references {{ index .Options "Namespace" }}.users on delete cascade, + token_type one_time_token_type not null, + token_hash text not null, + relates_to text not null, + created_at timestamp without time zone not null default now(), + updated_at timestamp without time zone not null default now(), + check (char_length(token_hash) > 0) + ); + + begin + create index if not exists one_time_tokens_token_hash_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using hash (token_hash); + create index if not exists one_time_tokens_relates_to_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using hash (relates_to); + exception when others then + -- Fallback to btree indexes if hash creation fails + create index if not exists one_time_tokens_token_hash_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using btree (token_hash); + create index if not exists one_time_tokens_relates_to_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using btree (relates_to); + end; + + create unique index if not exists one_time_tokens_user_id_token_type_key on {{ index .Options "Namespace" }}.one_time_tokens (user_id, token_type); +end $$; diff --git a/migrations/20240612123726_enable_rls_update_grants.up.sql b/migrations/20240612123726_enable_rls_update_grants.up.sql new file mode 100644 index 000000000..9201e8496 --- /dev/null +++ b/migrations/20240612123726_enable_rls_update_grants.up.sql @@ -0,0 +1,36 @@ +do $$ begin + -- enable RLS policy on auth tables + alter table {{ index .Options "Namespace" }}.schema_migrations enable row level security; + alter table {{ index .Options "Namespace" }}.instances enable row level security; + alter table {{ index .Options "Namespace" }}.users enable row level security; + alter table {{ index .Options "Namespace" }}.audit_log_entries enable row level security; + alter table {{ index .Options "Namespace" }}.saml_relay_states enable row level security; + alter table {{ index .Options "Namespace" }}.refresh_tokens enable row level security; + alter table {{ index .Options "Namespace" }}.mfa_factors enable row level security; + alter table {{ index .Options "Namespace" }}.sessions enable row level security; + alter table {{ index .Options "Namespace" }}.sso_providers enable row level security; + alter table {{ index .Options "Namespace" }}.sso_domains enable row level security; + alter table {{ index .Options "Namespace" }}.mfa_challenges enable row level security; + alter table {{ index .Options "Namespace" }}.mfa_amr_claims enable row level security; + alter table {{ index .Options "Namespace" }}.saml_providers enable row level security; + alter table {{ index .Options "Namespace" }}.flow_state enable row level security; + alter table {{ index .Options "Namespace" }}.identities enable row level security; + alter table {{ index .Options "Namespace" }}.one_time_tokens enable row level security; + -- allow postgres role to select from auth tables and allow it to grant select to other roles + grant select on {{ index .Options "Namespace" }}.schema_migrations to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.instances to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.users to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.audit_log_entries to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.saml_relay_states to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.refresh_tokens to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.mfa_factors to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.sessions to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.sso_providers to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.sso_domains to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.mfa_challenges to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.mfa_amr_claims to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.saml_providers to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.flow_state to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.identities to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.one_time_tokens to postgres with grant option; +end $$; diff --git a/migrations/20240729123726_add_mfa_phone_config.up.sql b/migrations/20240729123726_add_mfa_phone_config.up.sql new file mode 100644 index 000000000..ec94d7bcb --- /dev/null +++ b/migrations/20240729123726_add_mfa_phone_config.up.sql @@ -0,0 +1,12 @@ +do $$ begin + alter type {{ index .Options "Namespace" }}.factor_type add value 'phone'; +exception + when duplicate_object then null; +end $$; + + +alter table {{ index .Options "Namespace" }}.mfa_factors add column if not exists phone text unique default null; +alter table {{ index .Options "Namespace" }}.mfa_challenges add column if not exists otp_code text null; + + +create unique index if not exists unique_verified_phone_factor on {{ index .Options "Namespace" }}.mfa_factors (user_id, phone); diff --git a/migrations/20240802193726_add_mfa_factors_column_last_challenged_at.up.sql b/migrations/20240802193726_add_mfa_factors_column_last_challenged_at.up.sql new file mode 100644 index 000000000..bc3eea989 --- /dev/null +++ b/migrations/20240802193726_add_mfa_factors_column_last_challenged_at.up.sql @@ -0,0 +1 @@ +alter table {{ index .Options "Namespace" }}.mfa_factors add column if not exists last_challenged_at timestamptz unique default null; diff --git a/migrations/20240806073726_drop_uniqueness_constraint_on_phone.up.sql b/migrations/20240806073726_drop_uniqueness_constraint_on_phone.up.sql new file mode 100644 index 000000000..ade27eaf6 --- /dev/null +++ b/migrations/20240806073726_drop_uniqueness_constraint_on_phone.up.sql @@ -0,0 +1,22 @@ +alter table {{ index .Options "Namespace" }}.mfa_factors drop constraint if exists mfa_factors_phone_key; +do $$ +begin + -- if both indexes exist, it means that the schema_migrations table was truncated and the migrations had to be rerun + if ( + select count(*) = 2 + from pg_indexes + where indexname in ('unique_verified_phone_factor', 'unique_phone_factor_per_user') + and schemaname = '{{ index .Options "Namespace" }}' + ) then + execute 'drop index {{ index .Options "Namespace" }}.unique_verified_phone_factor'; + end if; + + if exists ( + select 1 + from pg_indexes + where indexname = 'unique_verified_phone_factor' + and schemaname = '{{ index .Options "Namespace" }}' + ) then + execute 'alter index {{ index .Options "Namespace" }}.unique_verified_phone_factor rename to unique_phone_factor_per_user'; + end if; +end $$; diff --git a/migrations/20241009103726_add_web_authn.up.sql b/migrations/20241009103726_add_web_authn.up.sql new file mode 100644 index 000000000..04d897265 --- /dev/null +++ b/migrations/20241009103726_add_web_authn.up.sql @@ -0,0 +1,3 @@ +alter table {{ index .Options "Namespace" }}.mfa_factors add column if not exists web_authn_credential jsonb null; +alter table {{ index .Options "Namespace" }}.mfa_factors add column if not exists web_authn_aaguid uuid null; +alter table {{ index .Options "Namespace" }}.mfa_challenges add column if not exists web_authn_session_data jsonb null; diff --git a/models/audit_log_entry.go b/models/audit_log_entry.go deleted file mode 100644 index 3aa39b742..000000000 --- a/models/audit_log_entry.go +++ /dev/null @@ -1,133 +0,0 @@ -package models - -import ( - "bytes" - "fmt" - "time" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" -) - -type AuditAction string -type auditLogType string - -const ( - LoginAction AuditAction = "login" - LogoutAction AuditAction = "logout" - InviteAcceptedAction AuditAction = "invite_accepted" - UserSignedUpAction AuditAction = "user_signedup" - UserInvitedAction AuditAction = "user_invited" - UserDeletedAction AuditAction = "user_deleted" - UserModifiedAction AuditAction = "user_modified" - UserRecoveryRequestedAction AuditAction = "user_recovery_requested" - UserConfirmationRequestedAction AuditAction = "user_confirmation_requested" - UserRepeatedSignUpAction AuditAction = "user_repeated_signup" - TokenRevokedAction AuditAction = "token_revoked" - TokenRefreshedAction AuditAction = "token_refreshed" - - account auditLogType = "account" - team auditLogType = "team" - token auditLogType = "token" - user auditLogType = "user" -) - -var actionLogTypeMap = map[AuditAction]auditLogType{ - LoginAction: account, - LogoutAction: account, - InviteAcceptedAction: account, - UserSignedUpAction: team, - UserInvitedAction: team, - UserDeletedAction: team, - TokenRevokedAction: token, - TokenRefreshedAction: token, - UserModifiedAction: user, - UserRecoveryRequestedAction: user, - UserConfirmationRequestedAction: user, - UserRepeatedSignUpAction: user, -} - -// AuditLogEntry is the database model for audit log entries. -type AuditLogEntry struct { - InstanceID uuid.UUID `json:"-" db:"instance_id"` - ID uuid.UUID `json:"id" db:"id"` - - Payload JSONMap `json:"payload" db:"payload"` - - CreatedAt time.Time `json:"created_at" db:"created_at"` -} - -func (AuditLogEntry) TableName() string { - tableName := "audit_log_entries" - return tableName -} - -func NewAuditLogEntry(tx *storage.Connection, instanceID uuid.UUID, actor *User, action AuditAction, traits map[string]interface{}) error { - id, err := uuid.NewV4() - if err != nil { - return errors.Wrap(err, "Error generating unique id") - } - - username := actor.GetEmail() - - if actor.GetPhone() != "" { - username = actor.GetPhone() - } - - l := AuditLogEntry{ - InstanceID: instanceID, - ID: id, - Payload: JSONMap{ - "timestamp": time.Now().UTC().Format(time.RFC3339), - "actor_id": actor.ID, - "actor_username": username, - "action": action, - "log_type": actionLogTypeMap[action], - }, - } - - if name, ok := actor.UserMetaData["full_name"]; ok { - l.Payload["actor_name"] = name - } - - if traits != nil { - l.Payload["traits"] = traits - } - - return errors.Wrap(tx.Create(&l), "Database error creating audit log entry") -} - -func FindAuditLogEntries(tx *storage.Connection, instanceID uuid.UUID, filterColumns []string, filterValue string, pageParams *Pagination) ([]*AuditLogEntry, error) { - q := tx.Q().Order("created_at desc").Where("instance_id = ?", instanceID) - - if len(filterColumns) > 0 && filterValue != "" { - lf := "%" + filterValue + "%" - - builder := bytes.NewBufferString("(") - values := make([]interface{}, len(filterColumns)) - - for idx, col := range filterColumns { - builder.WriteString(fmt.Sprintf("payload->>'%s' ILIKE ?", col)) - values[idx] = lf - - if idx+1 < len(filterColumns) { - builder.WriteString(" OR ") - } - } - builder.WriteString(")") - - q = q.Where(builder.String(), values...) - } - - logs := []*AuditLogEntry{} - var err error - if pageParams != nil { - err = q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&logs) - pageParams.Count = uint64(q.Paginator.TotalEntriesSize) - } else { - err = q.All(&logs) - } - - return logs, err -} diff --git a/models/connection.go b/models/connection.go deleted file mode 100644 index 7c01a1de8..000000000 --- a/models/connection.go +++ /dev/null @@ -1,46 +0,0 @@ -package models - -import ( - "github.com/gobuffalo/pop/v5" - "github.com/netlify/gotrue/storage" -) - -type Pagination struct { - Page uint64 - PerPage uint64 - Count uint64 -} - -func (p *Pagination) Offset() uint64 { - return (p.Page - 1) * p.PerPage -} - -type SortDirection string - -const Ascending SortDirection = "ASC" -const Descending SortDirection = "DESC" -const CreatedAt = "created_at" - -type SortParams struct { - Fields []SortField -} - -type SortField struct { - Name string - Dir SortDirection -} - -func TruncateAll(conn *storage.Connection) error { - return conn.Transaction(func(tx *storage.Connection) error { - if err := tx.RawQuery("TRUNCATE " + (&pop.Model{Value: User{}}).TableName() + " CASCADE").Exec(); err != nil { - return err - } - if err := tx.RawQuery("TRUNCATE " + (&pop.Model{Value: RefreshToken{}}).TableName() + " CASCADE").Exec(); err != nil { - return err - } - if err := tx.RawQuery("TRUNCATE " + (&pop.Model{Value: AuditLogEntry{}}).TableName() + " CASCADE").Exec(); err != nil { - return err - } - return tx.RawQuery("TRUNCATE " + (&pop.Model{Value: Instance{}}).TableName() + " CASCADE").Exec() - }) -} diff --git a/models/db_test.go b/models/db_test.go deleted file mode 100644 index 84a10e048..000000000 --- a/models/db_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package models_test - -import ( - "testing" - - "github.com/gobuffalo/pop/v5" - "github.com/netlify/gotrue/models" - "github.com/stretchr/testify/assert" -) - -func TestTableNameNamespacing(t *testing.T) { - cases := []struct { - expected string - value interface{} - }{ - {expected: "audit_log_entries", value: []*models.AuditLogEntry{}}, - {expected: "instances", value: []*models.Instance{}}, - {expected: "refresh_tokens", value: []*models.RefreshToken{}}, - {expected: "users", value: []*models.User{}}, - } - - for _, tc := range cases { - m := &pop.Model{Value: tc.value} - assert.Equal(t, tc.expected, m.TableName()) - } -} diff --git a/models/errors.go b/models/errors.go deleted file mode 100644 index 6c33c70c2..000000000 --- a/models/errors.go +++ /dev/null @@ -1,61 +0,0 @@ -package models - -// IsNotFoundError returns whether an error represents a "not found" error. -func IsNotFoundError(err error) bool { - switch err.(type) { - case UserNotFoundError: - return true - case ConfirmationTokenNotFoundError: - return true - case RefreshTokenNotFoundError: - return true - case InstanceNotFoundError: - return true - case TotpSecretNotFoundError: - return true - case IdentityNotFoundError: - return true - } - return false -} - -// UserNotFoundError represents when a user is not found. -type UserNotFoundError struct{} - -func (e UserNotFoundError) Error() string { - return "User not found" -} - -// IdentityNotFoundError represents when an identity is not found. -type IdentityNotFoundError struct{} - -func (e IdentityNotFoundError) Error() string { - return "Identity not found" -} - -// ConfirmationTokenNotFoundError represents when a confirmation token is not found. -type ConfirmationTokenNotFoundError struct{} - -func (e ConfirmationTokenNotFoundError) Error() string { - return "Confirmation Token not found" -} - -// RefreshTokenNotFoundError represents when a refresh token is not found. -type RefreshTokenNotFoundError struct{} - -func (e RefreshTokenNotFoundError) Error() string { - return "Refresh Token not found" -} - -// InstanceNotFoundError represents when an instance is not found. -type InstanceNotFoundError struct{} - -func (e InstanceNotFoundError) Error() string { - return "Instance not found" -} - -type TotpSecretNotFoundError struct{} - -func (e TotpSecretNotFoundError) Error() string { - return "Totp Secret not found" -} diff --git a/models/identity.go b/models/identity.go deleted file mode 100644 index 74b9383b2..000000000 --- a/models/identity.go +++ /dev/null @@ -1,84 +0,0 @@ -package models - -import ( - "database/sql" - "time" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" -) - -type Identity struct { - ID string `json:"id" db:"id"` - UserID uuid.UUID `json:"user_id" db:"user_id"` - IdentityData JSONMap `json:"identity_data,omitempty" db:"identity_data"` - Provider string `json:"provider" db:"provider"` - LastSignInAt *time.Time `json:"last_sign_in_at,omitempty" db:"last_sign_in_at"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` -} - -func (Identity) TableName() string { - tableName := "identities" - return tableName -} - -// NewIdentity returns an identity associated to the user's id. -func NewIdentity(user *User, provider string, identityData map[string]interface{}) (*Identity, error) { - id, ok := identityData["sub"] - if !ok { - return nil, errors.New("Error missing provider id") - } - now := time.Now() - - identity := &Identity{ - ID: id.(string), - UserID: user.ID, - IdentityData: identityData, - Provider: provider, - LastSignInAt: &now, - } - - return identity, nil -} - -// FindIdentityById searches for an identity with the matching provider_id and provider given. -func FindIdentityByIdAndProvider(tx *storage.Connection, providerId, provider string) (*Identity, error) { - identity := &Identity{} - if err := tx.Q().Where("id = ? AND provider = ?", providerId, provider).First(identity); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, IdentityNotFoundError{} - } - return nil, errors.Wrap(err, "error finding identity") - } - return identity, nil -} - -// FindIdentitiesByUser returns all identities associated to a user -func FindIdentitiesByUser(tx *storage.Connection, user *User) ([]*Identity, error) { - identities := []*Identity{} - if err := tx.Q().Where("user_id = ?", user.ID).All(&identities); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return identities, nil - } - return nil, errors.Wrap(err, "error finding identities") - } - return identities, nil -} - -// FindProvidersByUser returns all providers associated to a user -func FindProvidersByUser(tx *storage.Connection, user *User) ([]string, error) { - identities := []Identity{} - providers := make([]string, 0) - if err := tx.Q().Select("provider").Where("user_id = ?", user.ID).All(&identities); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return providers, nil - } - return nil, errors.Wrap(err, "error finding providers") - } - for _, identity := range identities { - providers = append(providers, identity.Provider) - } - return providers, nil -} diff --git a/models/instance.go b/models/instance.go deleted file mode 100644 index 009c456a6..000000000 --- a/models/instance.go +++ /dev/null @@ -1,89 +0,0 @@ -package models - -import ( - "database/sql" - "time" - - "github.com/gobuffalo/pop/v5" - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" -) - -const baseConfigKey = "" - -type Instance struct { - ID uuid.UUID `json:"id" db:"id"` - // Netlify UUID - UUID uuid.UUID `json:"uuid,omitempty" db:"uuid"` - - BaseConfig *conf.Configuration `json:"config" db:"raw_base_config"` - - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` -} - -func (Instance) TableName() string { - tableName := "instances" - return tableName -} - -// Config loads the the base configuration values with defaults. -func (i *Instance) Config() (*conf.Configuration, error) { - if i.BaseConfig == nil { - return nil, errors.New("no configuration data available") - } - - baseConf := &conf.Configuration{} - *baseConf = *i.BaseConfig - baseConf.ApplyDefaults() - - return baseConf, nil -} - -// UpdateConfig updates the base config -func (i *Instance) UpdateConfig(tx *storage.Connection, config *conf.Configuration) error { - i.BaseConfig = config - return tx.UpdateOnly(i, "raw_base_config") -} - -// GetInstance finds an instance by ID -func GetInstance(tx *storage.Connection, instanceID uuid.UUID) (*Instance, error) { - instance := Instance{} - if err := tx.Find(&instance, instanceID); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, InstanceNotFoundError{} - } - return nil, errors.Wrap(err, "error finding instance") - } - return &instance, nil -} - -func GetInstanceByUUID(tx *storage.Connection, uuid uuid.UUID) (*Instance, error) { - instance := Instance{} - if err := tx.Where("uuid = ?", uuid).First(&instance); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, InstanceNotFoundError{} - } - return nil, errors.Wrap(err, "error finding instance") - } - return &instance, nil -} - -func DeleteInstance(conn *storage.Connection, instance *Instance) error { - return conn.Transaction(func(tx *storage.Connection) error { - delModels := map[string]*pop.Model{ - "user": &pop.Model{Value: &User{}}, - "refresh token": &pop.Model{Value: &RefreshToken{}}, - } - - for name, dm := range delModels { - if err := tx.RawQuery("DELETE FROM "+dm.TableName()+" WHERE instance_id = ?", instance.ID).Exec(); err != nil { - return errors.Wrapf(err, "Error deleting %s records", name) - } - } - - return errors.Wrap(tx.Destroy(instance), "Error deleting instance record") - }) -} diff --git a/models/refresh_token.go b/models/refresh_token.go deleted file mode 100644 index 054806e4e..000000000 --- a/models/refresh_token.go +++ /dev/null @@ -1,97 +0,0 @@ -package models - -import ( - "time" - - "github.com/gobuffalo/pop/v5" - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/crypto" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" -) - -// RefreshToken is the database model for refresh tokens. -type RefreshToken struct { - InstanceID uuid.UUID `json:"-" db:"instance_id"` - ID int64 `db:"id"` - - Token string `db:"token"` - - UserID uuid.UUID `db:"user_id"` - - Parent storage.NullString `db:"parent"` - - Revoked bool `db:"revoked"` - CreatedAt time.Time `db:"created_at"` - UpdatedAt time.Time `db:"updated_at"` -} - -func (RefreshToken) TableName() string { - tableName := "refresh_tokens" - return tableName -} - -// GrantAuthenticatedUser creates a refresh token for the provided user. -func GrantAuthenticatedUser(tx *storage.Connection, user *User) (*RefreshToken, error) { - return createRefreshToken(tx, user, nil) -} - -// GrantRefreshTokenSwap swaps a refresh token for a new one, revoking the provided token. -func GrantRefreshTokenSwap(tx *storage.Connection, user *User, token *RefreshToken) (*RefreshToken, error) { - var newToken *RefreshToken - err := tx.Transaction(func(rtx *storage.Connection) error { - var terr error - if terr = NewAuditLogEntry(tx, user.InstanceID, user, TokenRevokedAction, nil); terr != nil { - return errors.Wrap(terr, "error creating audit log entry") - } - - token.Revoked = true - if terr = tx.UpdateOnly(token, "revoked"); terr != nil { - return terr - } - newToken, terr = createRefreshToken(rtx, user, token) - return terr - }) - return newToken, err -} - -// RevokeTokenFamily revokes all refresh tokens that descended from the provided token. -func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error { - err := tx.RawQuery(` - with recursive token_family as ( - select id, user_id, token, revoked, parent from refresh_tokens where parent = ? - union - select r.id, r.user_id, r.token, r.revoked, r.parent from `+(&pop.Model{Value: RefreshToken{}}).TableName()+` r inner join token_family t on t.token = r.parent - ) - update `+(&pop.Model{Value: RefreshToken{}}).TableName()+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec() - if err != nil { - return err - } - return nil -} - -// Logout deletes all refresh tokens for a user. -func Logout(tx *storage.Connection, instanceID uuid.UUID, id uuid.UUID) error { - return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: RefreshToken{}}).TableName()+" WHERE instance_id = ? AND user_id = ?", instanceID, id).Exec() -} - -func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken) (*RefreshToken, error) { - token := &RefreshToken{ - InstanceID: user.InstanceID, - UserID: user.ID, - Token: crypto.SecureToken(), - Parent: "", - } - if oldToken != nil { - token.Parent = storage.NullString(oldToken.Token) - } - - if err := tx.Create(token); err != nil { - return nil, errors.Wrap(err, "error creating refresh token") - } - - if err := user.UpdateLastSignInAt(tx); err != nil { - return nil, errors.Wrap(err, "error update user`s last_sign_in field") - } - return token, nil -} diff --git a/models/user.go b/models/user.go deleted file mode 100644 index d2f7083e1..000000000 --- a/models/user.go +++ /dev/null @@ -1,475 +0,0 @@ -package models - -import ( - "database/sql" - "strings" - "time" - - "github.com/gobuffalo/pop/v5" - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/storage" - "github.com/pkg/errors" - "golang.org/x/crypto/bcrypt" -) - -const SystemUserID = "0" - -var SystemUserUUID = uuid.Nil - -// User respresents a registered user with email/password authentication -type User struct { - InstanceID uuid.UUID `json:"-" db:"instance_id"` - ID uuid.UUID `json:"id" db:"id"` - - Aud string `json:"aud" db:"aud"` - Role string `json:"role" db:"role"` - Email storage.NullString `json:"email" db:"email"` - EncryptedPassword string `json:"-" db:"encrypted_password"` - EmailConfirmedAt *time.Time `json:"email_confirmed_at,omitempty" db:"email_confirmed_at"` - InvitedAt *time.Time `json:"invited_at,omitempty" db:"invited_at"` - - Phone storage.NullString `json:"phone" db:"phone"` - PhoneConfirmedAt *time.Time `json:"phone_confirmed_at,omitempty" db:"phone_confirmed_at"` - - ConfirmationToken string `json:"-" db:"confirmation_token"` - ConfirmationSentAt *time.Time `json:"confirmation_sent_at,omitempty" db:"confirmation_sent_at"` - - // For backward compatibility only. Use EmailConfirmedAt or PhoneConfirmedAt instead. - ConfirmedAt *time.Time `json:"confirmed_at,omitempty" db:"confirmed_at" rw:"r"` - - RecoveryToken string `json:"-" db:"recovery_token"` - RecoverySentAt *time.Time `json:"recovery_sent_at,omitempty" db:"recovery_sent_at"` - - EmailChangeTokenCurrent string `json:"-" db:"email_change_token_current"` - EmailChangeTokenNew string `json:"-" db:"email_change_token_new"` - EmailChange string `json:"new_email,omitempty" db:"email_change"` - EmailChangeSentAt *time.Time `json:"email_change_sent_at,omitempty" db:"email_change_sent_at"` - EmailChangeConfirmStatus int `json:"-" db:"email_change_confirm_status"` - - PhoneChangeToken string `json:"-" db:"phone_change_token"` - PhoneChange string `json:"new_phone,omitempty" db:"phone_change"` - PhoneChangeSentAt *time.Time `json:"phone_change_sent_at,omitempty" db:"phone_change_sent_at"` - - LastSignInAt *time.Time `json:"last_sign_in_at,omitempty" db:"last_sign_in_at"` - - AppMetaData JSONMap `json:"app_metadata" db:"raw_app_meta_data"` - UserMetaData JSONMap `json:"user_metadata" db:"raw_user_meta_data"` - - IsSuperAdmin bool `json:"-" db:"is_super_admin"` - Identities []Identity `json:"identities" has_many:"identities"` - - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` - BannedUntil *time.Time `json:"banned_until,omitempty" db:"banned_until"` -} - -// NewUser initializes a new user from an email, password and user data. -// TODO: Refactor NewUser to take in phone as an arg -func NewUser(instanceID uuid.UUID, email, password, aud string, userData map[string]interface{}) (*User, error) { - id, err := uuid.NewV4() - if err != nil { - return nil, errors.Wrap(err, "Error generating unique id") - } - pw, err := hashPassword(password) - if err != nil { - return nil, err - } - if userData == nil { - userData = make(map[string]interface{}) - } - user := &User{ - InstanceID: instanceID, - ID: id, - Aud: aud, - Email: storage.NullString(strings.ToLower(email)), - UserMetaData: userData, - EncryptedPassword: pw, - } - return user, nil -} - -// NewSystemUser returns a user with the id as SystemUserUUID -func NewSystemUser(instanceID uuid.UUID, aud string) *User { - return &User{ - InstanceID: instanceID, - ID: SystemUserUUID, - Aud: aud, - IsSuperAdmin: true, - } -} - -// TableName overrides the table name used by pop -func (User) TableName() string { - tableName := "users" - return tableName -} - -// BeforeCreate is invoked before a create operation is ran -func (u *User) BeforeCreate(tx *pop.Connection) error { - return u.BeforeUpdate(tx) -} - -// BeforeUpdate is invoked before an update operation is ran -func (u *User) BeforeUpdate(tx *pop.Connection) error { - if u.ID == SystemUserUUID { - return errors.New("Cannot persist system user") - } - - return nil -} - -// BeforeSave is invoked before the user is saved to the database -func (u *User) BeforeSave(tx *pop.Connection) error { - if u.ID == SystemUserUUID { - return errors.New("Cannot persist system user") - } - - if u.EmailConfirmedAt != nil && u.EmailConfirmedAt.IsZero() { - u.EmailConfirmedAt = nil - } - if u.PhoneConfirmedAt != nil && u.PhoneConfirmedAt.IsZero() { - u.PhoneConfirmedAt = nil - } - if u.InvitedAt != nil && u.InvitedAt.IsZero() { - u.InvitedAt = nil - } - if u.ConfirmationSentAt != nil && u.ConfirmationSentAt.IsZero() { - u.ConfirmationSentAt = nil - } - if u.RecoverySentAt != nil && u.RecoverySentAt.IsZero() { - u.RecoverySentAt = nil - } - if u.EmailChangeSentAt != nil && u.EmailChangeSentAt.IsZero() { - u.EmailChangeSentAt = nil - } - if u.PhoneChangeSentAt != nil && u.PhoneChangeSentAt.IsZero() { - u.PhoneChangeSentAt = nil - } - if u.LastSignInAt != nil && u.LastSignInAt.IsZero() { - u.LastSignInAt = nil - } - if u.BannedUntil != nil && u.BannedUntil.IsZero() { - u.BannedUntil = nil - } - return nil -} - -// IsConfirmed checks if a user has already been -// registered and confirmed. -func (u *User) IsConfirmed() bool { - return u.EmailConfirmedAt != nil -} - -// IsPhoneConfirmed checks if a user's phone has already been -// registered and confirmed. -func (u *User) IsPhoneConfirmed() bool { - return u.PhoneConfirmedAt != nil -} - -// SetRole sets the users Role to roleName -func (u *User) SetRole(tx *storage.Connection, roleName string) error { - u.Role = strings.TrimSpace(roleName) - return tx.UpdateOnly(u, "role") -} - -// HasRole returns true when the users role is set to roleName -func (u *User) HasRole(roleName string) bool { - return u.Role == roleName -} - -// GetEmail returns the user's email as a string -func (u *User) GetEmail() string { - return string(u.Email) -} - -// GetPhone returns the user's phone number as a string -func (u *User) GetPhone() string { - return string(u.Phone) -} - -// UpdateUserMetaData sets all user data from a map of updates, -// ensuring that it doesn't override attributes that are not -// in the provided map. -func (u *User) UpdateUserMetaData(tx *storage.Connection, updates map[string]interface{}) error { - if u.UserMetaData == nil { - u.UserMetaData = updates - } else if updates != nil { - for key, value := range updates { - if value != nil { - u.UserMetaData[key] = value - } else { - delete(u.UserMetaData, key) - } - } - } - return tx.UpdateOnly(u, "raw_user_meta_data") -} - -// UpdateAppMetaData updates all app data from a map of updates -func (u *User) UpdateAppMetaData(tx *storage.Connection, updates map[string]interface{}) error { - if u.AppMetaData == nil { - u.AppMetaData = updates - } else if updates != nil { - for key, value := range updates { - if value != nil { - u.AppMetaData[key] = value - } else { - delete(u.AppMetaData, key) - } - } - } - return tx.UpdateOnly(u, "raw_app_meta_data") -} - -// UpdateAppMetaDataProviders updates the provider field in AppMetaData column -func (u *User) UpdateAppMetaDataProviders(tx *storage.Connection) error { - providers, terr := FindProvidersByUser(tx, u) - if terr != nil { - return terr - } - return u.UpdateAppMetaData(tx, map[string]interface{}{ - "providers": providers, - }) -} - -// SetEmail sets the user's email -func (u *User) SetEmail(tx *storage.Connection, email string) error { - u.Email = storage.NullString(email) - return tx.UpdateOnly(u, "email") -} - -// SetPhone sets the user's phone -func (u *User) SetPhone(tx *storage.Connection, phone string) error { - u.Phone = storage.NullString(phone) - return tx.UpdateOnly(u, "phone") -} - -// hashPassword generates a hashed password from a plaintext string -func hashPassword(password string) (string, error) { - pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return "", err - } - return string(pw), nil -} - -// UpdatePassword updates the user's password -func (u *User) UpdatePassword(tx *storage.Connection, password string) error { - pw, err := hashPassword(password) - if err != nil { - return err - } - u.EncryptedPassword = pw - return tx.UpdateOnly(u, "encrypted_password") -} - -// UpdatePhone updates the user's phone -func (u *User) UpdatePhone(tx *storage.Connection, phone string) error { - u.Phone = storage.NullString(phone) - return tx.UpdateOnly(u, "phone") -} - -// Authenticate a user from a password -func (u *User) Authenticate(password string) bool { - err := bcrypt.CompareHashAndPassword([]byte(u.EncryptedPassword), []byte(password)) - return err == nil -} - -// Confirm resets the confimation token and sets the confirm timestamp -func (u *User) Confirm(tx *storage.Connection) error { - u.ConfirmationToken = "" - now := time.Now() - u.EmailConfirmedAt = &now - return tx.UpdateOnly(u, "confirmation_token", "email_confirmed_at") -} - -// ConfirmPhone resets the confimation token and sets the confirm timestamp -func (u *User) ConfirmPhone(tx *storage.Connection) error { - u.ConfirmationToken = "" - now := time.Now() - u.PhoneConfirmedAt = &now - return tx.UpdateOnly(u, "confirmation_token", "phone_confirmed_at") -} - -// UpdateLastSignInAt update field last_sign_in_at for user according to specified field -func (u *User) UpdateLastSignInAt(tx *storage.Connection) error { - return tx.UpdateOnly(u, "last_sign_in_at") -} - -// ConfirmEmailChange confirm the change of email for a user -func (u *User) ConfirmEmailChange(tx *storage.Connection, status int) error { - u.Email = storage.NullString(u.EmailChange) - u.EmailChange = "" - u.EmailChangeTokenCurrent = "" - u.EmailChangeTokenNew = "" - u.EmailChangeConfirmStatus = status - return tx.UpdateOnly( - u, - "email", - "email_change", - "email_change_token_current", - "email_change_token_new", - "email_change_confirm_status", - ) -} - -// ConfirmPhoneChange confirms the change of phone for a user -func (u *User) ConfirmPhoneChange(tx *storage.Connection) error { - u.Phone = storage.NullString(u.PhoneChange) - u.PhoneChange = "" - u.PhoneChangeToken = "" - now := time.Now() - u.PhoneConfirmedAt = &now - return tx.UpdateOnly(u, "phone", "phone_change", "phone_change_token", "phone_confirmed_at") -} - -// Recover resets the recovery token -func (u *User) Recover(tx *storage.Connection) error { - u.RecoveryToken = "" - return tx.UpdateOnly(u, "recovery_token") -} - -// CountOtherUsers counts how many other users exist besides the one provided -func CountOtherUsers(tx *storage.Connection, instanceID, id uuid.UUID) (int, error) { - userCount, err := tx.Q().Where("instance_id = ? and id != ?", instanceID, id).Count(&User{}) - return userCount, errors.Wrap(err, "error finding registered users") -} - -func findUser(tx *storage.Connection, query string, args ...interface{}) (*User, error) { - obj := &User{} - if err := tx.Eager().Q().Where(query, args...).First(obj); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, UserNotFoundError{} - } - return nil, errors.Wrap(err, "error finding user") - } - - return obj, nil -} - -// FindUserByConfirmationToken finds users with the matching confirmation token. -func FindUserByConfirmationToken(tx *storage.Connection, token string) (*User, error) { - user, err := findUser(tx, "confirmation_token = ?", token) - if err != nil { - return nil, ConfirmationTokenNotFoundError{} - } - return user, nil -} - -// FindUserByEmailAndAudience finds a user with the matching email and audience. -func FindUserByEmailAndAudience(tx *storage.Connection, instanceID uuid.UUID, email, aud string) (*User, error) { - return findUser(tx, "instance_id = ? and LOWER(email) = ? and aud = ?", instanceID, strings.ToLower(email), aud) -} - -// FindUserByPhoneAndAudience finds a user with the matching email and audience. -func FindUserByPhoneAndAudience(tx *storage.Connection, instanceID uuid.UUID, phone, aud string) (*User, error) { - return findUser(tx, "instance_id = ? and phone = ? and aud = ?", instanceID, phone, aud) -} - -// FindUserByID finds a user matching the provided ID. -func FindUserByID(tx *storage.Connection, id uuid.UUID) (*User, error) { - return findUser(tx, "id = ?", id) -} - -// FindUserByInstanceIDAndID finds a user matching the provided ID. -func FindUserByInstanceIDAndID(tx *storage.Connection, instanceID, id uuid.UUID) (*User, error) { - return findUser(tx, "instance_id = ? and id = ?", instanceID, id) -} - -// FindUserByRecoveryToken finds a user with the matching recovery token. -func FindUserByRecoveryToken(tx *storage.Connection, token string) (*User, error) { - return findUser(tx, "recovery_token = ?", token) -} - -// FindUserByEmailChangeToken finds a user with the matching email change token. -func FindUserByEmailChangeToken(tx *storage.Connection, token string) (*User, error) { - return findUser(tx, "email_change_token_current = ? or email_change_token_new = ?", token, token) -} - -// FindUserWithRefreshToken finds a user from the provided refresh token. -func FindUserWithRefreshToken(tx *storage.Connection, token string) (*User, *RefreshToken, error) { - refreshToken := &RefreshToken{} - if err := tx.Where("token = ?", token).First(refreshToken); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, nil, RefreshTokenNotFoundError{} - } - return nil, nil, errors.Wrap(err, "error finding refresh token") - } - - user, err := findUser(tx, "id = ?", refreshToken.UserID) - if err != nil { - return nil, nil, err - } - - return user, refreshToken, nil -} - -// FindUsersInAudience finds users with the matching audience. -func FindUsersInAudience(tx *storage.Connection, instanceID uuid.UUID, aud string, pageParams *Pagination, sortParams *SortParams, filter string) ([]*User, error) { - users := []*User{} - q := tx.Q().Where("instance_id = ? and aud = ?", instanceID, aud) - - if filter != "" { - lf := "%" + filter + "%" - // we must specify the collation in order to get case insensitive search for the JSON column - q = q.Where("(email LIKE ? OR raw_user_meta_data->>'full_name' ILIKE ?)", lf, lf) - } - - if sortParams != nil && len(sortParams.Fields) > 0 { - for _, field := range sortParams.Fields { - q = q.Order(field.Name + " " + string(field.Dir)) - } - } - - var err error - if pageParams != nil { - err = q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&users) - pageParams.Count = uint64(q.Paginator.TotalEntriesSize) - } else { - err = q.All(&users) - } - - return users, err -} - -// FindUserWithPhoneAndPhoneChangeToken finds a user with the matching phone and phone change token -func FindUserWithPhoneAndPhoneChangeToken(tx *storage.Connection, phone, token string) (*User, error) { - return findUser(tx, "phone = ? and phone_change_token = ?", phone, token) -} - -// IsDuplicatedEmail returns whether a user exists with a matching email and audience. -func IsDuplicatedEmail(tx *storage.Connection, instanceID uuid.UUID, email, aud string) (bool, error) { - _, err := FindUserByEmailAndAudience(tx, instanceID, email, aud) - if err != nil { - if IsNotFoundError(err) { - return false, nil - } - return false, err - } - return true, nil -} - -// IsDuplicatedPhone checks if the phone number already exists in the users table -func IsDuplicatedPhone(tx *storage.Connection, instanceID uuid.UUID, phone, aud string) (bool, error) { - _, err := FindUserByPhoneAndAudience(tx, instanceID, phone, aud) - if err != nil { - if IsNotFoundError(err) { - return false, nil - } - return false, err - } - return true, nil -} - -// IsBanned checks if a user is banned or not -func (u *User) IsBanned() bool { - if u.BannedUntil == nil { - return false - } - return time.Now().Before(*u.BannedUntil) -} - -func (u *User) UpdateBannedUntil(tx *storage.Connection) error { - return tx.UpdateOnly(u, "banned_until") - -} diff --git a/models/user_test.go b/models/user_test.go deleted file mode 100644 index 752adf44d..000000000 --- a/models/user_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package models - -import ( - "testing" - - "github.com/gofrs/uuid" - "github.com/netlify/gotrue/conf" - "github.com/netlify/gotrue/storage" - "github.com/netlify/gotrue/storage/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -const modelsTestConfig = "../hack/test.env" - -type UserTestSuite struct { - suite.Suite - db *storage.Connection -} - -func (ts *UserTestSuite) SetupTest() { - TruncateAll(ts.db) -} - -func TestUser(t *testing.T) { - globalConfig, err := conf.LoadGlobal(modelsTestConfig) - require.NoError(t, err) - - conn, err := test.SetupDBConnection(globalConfig) - require.NoError(t, err) - - ts := &UserTestSuite{ - db: conn, - } - defer ts.db.Close() - - suite.Run(t, ts) -} - -func (ts *UserTestSuite) TestUpdateAppMetadata() { - u, err := NewUser(uuid.Nil, "", "", "", nil) - require.NoError(ts.T(), err) - require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, make(map[string]interface{}))) - - require.NotNil(ts.T(), u.AppMetaData) - - require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, map[string]interface{}{ - "foo": "bar", - })) - - require.Equal(ts.T(), "bar", u.AppMetaData["foo"]) - require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, map[string]interface{}{ - "foo": nil, - })) - require.Len(ts.T(), u.AppMetaData, 0) - require.Equal(ts.T(), nil, u.AppMetaData["foo"]) -} - -func (ts *UserTestSuite) TestUpdateUserMetadata() { - u, err := NewUser(uuid.Nil, "", "", "", nil) - require.NoError(ts.T(), err) - require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, make(map[string]interface{}))) - - require.NotNil(ts.T(), u.UserMetaData) - - require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, map[string]interface{}{ - "foo": "bar", - })) - - require.Equal(ts.T(), "bar", u.UserMetaData["foo"]) - require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, map[string]interface{}{ - "foo": nil, - })) - require.Len(ts.T(), u.UserMetaData, 0) - require.Equal(ts.T(), nil, u.UserMetaData["foo"]) -} - -func (ts *UserTestSuite) TestFindUserByConfirmationToken() { - u := ts.createUser() - - n, err := FindUserByConfirmationToken(ts.db, u.ConfirmationToken) - require.NoError(ts.T(), err) - require.Equal(ts.T(), u.ID, n.ID) -} - -func (ts *UserTestSuite) TestFindUserByEmailAndAudience() { - u := ts.createUser() - - n, err := FindUserByEmailAndAudience(ts.db, u.InstanceID, u.GetEmail(), "test") - require.NoError(ts.T(), err) - require.Equal(ts.T(), u.ID, n.ID) - - _, err = FindUserByEmailAndAudience(ts.db, u.InstanceID, u.GetEmail(), "invalid") - require.EqualError(ts.T(), err, UserNotFoundError{}.Error()) -} - -func (ts *UserTestSuite) TestFindUsersInAudience() { - u := ts.createUser() - - n, err := FindUsersInAudience(ts.db, u.InstanceID, u.Aud, nil, nil, "") - require.NoError(ts.T(), err) - require.Len(ts.T(), n, 1) - - p := Pagination{ - Page: 1, - PerPage: 50, - } - n, err = FindUsersInAudience(ts.db, u.InstanceID, u.Aud, &p, nil, "") - require.NoError(ts.T(), err) - require.Len(ts.T(), n, 1) - assert.Equal(ts.T(), uint64(1), p.Count) - - sp := &SortParams{ - Fields: []SortField{ - SortField{Name: "created_at", Dir: Descending}, - }, - } - n, err = FindUsersInAudience(ts.db, u.InstanceID, u.Aud, nil, sp, "") - require.NoError(ts.T(), err) - require.Len(ts.T(), n, 1) -} - -func (ts *UserTestSuite) TestFindUserByID() { - u := ts.createUser() - - n, err := FindUserByID(ts.db, u.ID) - require.NoError(ts.T(), err) - require.Equal(ts.T(), u.ID, n.ID) -} - -func (ts *UserTestSuite) TestFindUserByInstanceIDAndID() { - u := ts.createUser() - - n, err := FindUserByInstanceIDAndID(ts.db, u.InstanceID, u.ID) - require.NoError(ts.T(), err) - require.Equal(ts.T(), u.ID, n.ID) -} - -func (ts *UserTestSuite) TestFindUserByRecoveryToken() { - u := ts.createUser() - u.RecoveryToken = "asdf" - - err := ts.db.Update(u) - require.NoError(ts.T(), err) - - n, err := FindUserByRecoveryToken(ts.db, u.RecoveryToken) - require.NoError(ts.T(), err) - - require.Equal(ts.T(), u.ID, n.ID) -} - -func (ts *UserTestSuite) TestFindUserWithRefreshToken() { - u := ts.createUser() - r, err := GrantAuthenticatedUser(ts.db, u) - require.NoError(ts.T(), err) - - n, nr, err := FindUserWithRefreshToken(ts.db, r.Token) - require.NoError(ts.T(), err) - require.Equal(ts.T(), r.ID, nr.ID) - require.Equal(ts.T(), u.ID, n.ID) -} - -func (ts *UserTestSuite) TestIsDuplicatedEmail() { - u := ts.createUserWithEmail("david.calavera@netlify.com") - - e, err := IsDuplicatedEmail(ts.db, u.InstanceID, "david.calavera@netlify.com", "test") - require.NoError(ts.T(), err) - require.True(ts.T(), e, "expected email to be duplicated") - - e, err = IsDuplicatedEmail(ts.db, u.InstanceID, "davidcalavera@netlify.com", "test") - require.NoError(ts.T(), err) - require.False(ts.T(), e, "expected email to not be duplicated") - - e, err = IsDuplicatedEmail(ts.db, u.InstanceID, "david@netlify.com", "test") - require.NoError(ts.T(), err) - require.False(ts.T(), e, "expected same email to not be duplicated") - - e, err = IsDuplicatedEmail(ts.db, u.InstanceID, "david.calavera@netlify.com", "other-aud") - require.NoError(ts.T(), err) - require.False(ts.T(), e, "expected same email to not be duplicated") -} - -func (ts *UserTestSuite) createUser() *User { - return ts.createUserWithEmail("david@netlify.com") -} - -func (ts *UserTestSuite) createUserWithEmail(email string) *User { - user, err := NewUser(uuid.Nil, email, "secret", "test", nil) - require.NoError(ts.T(), err) - - err = ts.db.Create(user) - require.NoError(ts.T(), err) - - return user -} diff --git a/netlify.toml b/netlify.toml deleted file mode 100644 index 4bf4f5fe4..000000000 --- a/netlify.toml +++ /dev/null @@ -1,9 +0,0 @@ -[build] - publish = "www" - command = "exit 0" - -[[redirects]] -from = "/*" -to = "https://github.com/netlify/gotrue" -status = 302 -force = true diff --git a/openapi.yaml b/openapi.yaml new file mode 100644 index 000000000..7996f817f --- /dev/null +++ b/openapi.yaml @@ -0,0 +1,2506 @@ +openapi: 3.0.3 +info: + version: latest + title: Supabase Auth REST API + description: |- + This is the REST API for [Supabase Auth](https://supabase.com/auth). Visit https://supabase.com/docs/guides/auth for complete documentation. + + **Notes:** + - HTTP 5XX errors are not listed for each endpoint. + These should be handled globally. Not all HTTP 5XX errors are generated from Auth, and they may serve non-JSON content. Make sure you inspect the `Content-Type` header before parsing as JSON. + - Error responses are somewhat inconsistent. + Avoid using the `msg` and HTTP status code to identify errors. HTTP 400 and 422 are used interchangeably in many APIs. + - If the server has CAPTCHA protection enabled, the verification token should be included in the request body. + - Rate limit errors are consistently raised with the HTTP 429 code. + - Enums are used only in request bodies / parameters and not in responses to ensure wide compatibility with code generators that fail to include an unknown enum case. + + **Backward compatibility:** + - Endpoints marked as _Experimental_ may change without notice. + - Endpoints marked as _Deprecated_ will be supported for at least 3 months since being marked as deprecated. + - HTTP status codes like 400, 404, 422 may change for the same underlying error condition. + + termsOfService: https://supabase.com/terms + contact: + name: Ask a question about this API + url: https://github.com/supabase/supabase/discussions + license: + name: MIT License + url: https://github.com/supabase/auth/blob/master/LICENSE +externalDocs: + description: Learn more about Supabase Auth + url: https://supabase.com/docs/guides/auth/overview +servers: + - url: "https://{project}.supabase.co/auth/v1" + variables: + project: + description: > + Your Supabase project ID. + default: abcdefghijklmnopqrst +tags: + - name: auth + description: APIs for authentication and authorization. + - name: user + description: APIs used by a user to manage their account. + - name: oauth + description: APIs for dealing with OAuth flows. + - name: sso + description: APIs for authenticating using SSO providers (SAML). (Experimental.) + - name: saml + description: SAML 2.0 Endpoints. (Experimental.) + - name: admin + description: Administration APIs requiring elevated access. + - name: general + description: General APIs. +paths: + /token: + post: + summary: Issues access and refresh tokens based on grant type. + tags: + - auth + parameters: + - name: grant_type + in: query + required: true + description: > + - What grant type should be used to issue an access and refresh token. Note: `id_token` is only offered in experimental mode. + + - CAPTCHA protection does not apply on the `refresh_token` grant flow. + + - Using `password` is akin to a user signing in. + + - `pkce` is used for exchanging the authorization code for a pair of access and refresh tokens. + schema: + type: string + enum: + - password + - refresh_token + - id_token + - pkce + - web3 + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + examples: + grant_type=password: + value: + email: user@example.com + password: password1 + grant_type=refresh_token: + value: + refresh_token: 4nYUCw0wZR_DNOTSDbSGMQ + grant_type=pkce: + value: + auth_code: 009e5066-fc11-4eca-8c8c-6fd82aa263f2 + code_verifier: ktPNXpR65N6JtgzQA8_5HHtH6PBSAahMNoLKRzQEa0Tzgl.vdV~b6lPk004XOd.4lR0inCde.NoQx5K63xPfzL8o7tJAjXncnhw5Niv9ycQ.QRV9JG.y3VapqbgLfIrJ + grant_type=web3: + value: + message: "example.com wants you to sign in with your Solana account:\n0x1234...5678\n\nSign in with Solana\n\nURI: https://example.com\nVersion: 1\nNonce: abc123def456\nIssued At: 2023-09-19T12:00:00Z" + signature: "base64_encoded_signature_string" + chain: "solana" + schema: + type: object + description: |- + For the refresh token flow, supply only `refresh_token`. + For the email/phone with password flow, supply `email`, `phone` and `password` with an optional `gotrue_meta_security`. + For the OIDC ID token flow, supply `id_token`, `nonce`, `provider`, `client_id`, `issuer` with an optional `gotrue_meta_security`. + For the Web3 flow, supply `message`, `signature`, and `chain`. + properties: + refresh_token: + type: string + password: + type: string + email: + type: string + format: email + phone: + type: string + format: phone + id_token: + type: string + access_token: + type: string + description: Provide only when `grant_type` is `id_token` and the provided ID token requires the presence of an access token to be accepted (usually by having an `at_hash` claim). + nonce: + type: string + provider: + type: string + enum: + - google + - apple + - azure + - facebook + - keycloak + client_id: + type: string + issuer: + type: string + description: If `provider` is `azure` then you can specify any Azure OIDC issuer string here, which will be used for verification. + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueSecurity" + auth_code: + type: string + format: uuid + code_verifier: + type: string + message: + type: string + description: | + Signed message for Web3 authentication following the Sign in with Solana standard. Must include: `Issued At`, `URI`, `Version`. + signature: + type: string + description: The signature of the message for Web3 authentication encoded as Base64 or Base64-URL. + chain: + type: string + description: What blockchain is the Web3 message and signature for. + enum: + - solana + example: solana + responses: + 200: + description: > + An access and refresh token have been successfully issued. + content: + application/json: + schema: + $ref: "#/components/schemas/AccessTokenResponseSchema" + + 400: + $ref: "#/components/responses/BadRequestResponse" + 401: + $ref: "#/components/responses/ForbiddenResponse" + 403: + $ref: "#/components/responses/UnauthorizedResponse" + 500: + $ref: "#/components/responses/InternalServerErrorResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /logout: + post: + summary: Logs out a user. + tags: + - auth + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: scope + in: query + description: > + (Optional.) Determines how the user should be logged out. When `global` is used, the user is logged out from all active sessions. When `local` is used, the user is logged out from the current session. When `others` is used, the user is logged out from all other sessions except the current one. Clients should remove stored access and refresh tokens except when `others` is used. + schema: + type: string + enum: + - global + - local + - others + responses: + 204: + description: No content returned on successful logout. + 401: + $ref: "#/components/responses/UnauthorizedResponse" + + /verify: + get: + summary: Authenticate by verifying the possession of a one-time token. Usually for use as clickable links. + tags: + - auth + parameters: + - name: token + in: query + required: true + schema: + type: string + - name: type + in: query + required: true + schema: + type: string + enum: + - signup + - invite + - recovery + - magiclink + - email_change + - name: redirect_to + in: query + description: > + (Optional) URL to redirect back into the app on after verification completes successfully. If not specified will use the "Site URL" configuration option. If not allowed per the allow list it will use the "Site URL" configuration option. + schema: + type: string + format: uri + security: + - APIKeyAuth: [] + responses: + 302: + $ref: "#/components/responses/AccessRefreshTokenRedirectResponse" + post: + summary: Authenticate by verifying the possession of a one-time token. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + type: + type: string + enum: + - signup + - recovery + - invite + - magiclink + - email_change + - sms + - phone_change + token: + type: string + token_hash: + type: string + description: > + The hashed value of token. Applicable only if used with `type` and nothing else. + email: + type: string + format: email + description: > + Applicable only if `type` is with regards to an email address. + phone: + type: string + format: phone + description: > + Applicable only if `type` is with regards to an phone number. + redirect_to: + type: string + format: uri + description: > + (Optional) URL to redirect back into the app on after verification completes successfully. If not specified will use the "Site URL" configuration option. If not allowed per the allow list it will use the "Site URL" configuration option. + + responses: + 200: + description: An access and refresh token. + content: + application/json: + schema: + $ref: "#/components/schemas/AccessTokenResponseSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /authorize: + get: + summary: Redirects to an external OAuth provider. Usually for use as clickable links. + tags: + - oauth + security: + - APIKeyAuth: [] + parameters: + - name: provider + in: query + description: Name of the OAuth provider. + example: google + required: true + schema: + type: string + pattern: "^[a-zA-Z0-9]+$" + - name: scopes + in: query + required: true + description: Space separated list of OAuth scopes to pass on to `provider`. + schema: + type: string + pattern: "[^ ]+( +[^ ]+)*" + - name: invite_token + in: query + description: (Optional) A token representing a previous invitation of the user. A successful sign-in with OAuth will mark the invitation as completed. + schema: + type: string + - name: redirect_to + in: query + description: > + (Optional) URL to redirect back into the app on after OAuth sign-in completes successfully or not. If not specified will use the "Site URL" configuration option. If not allowed per the allow list it will use the "Site URL" configuration option. + schema: + type: string + format: uri + - name: code_challenge_method + in: query + description: (Optional) Method used to encrypt the verifier. Can be `plain` (no transformation) or `s256` (where SHA-256 is used). It is always recommended that `s256` is used. + schema: + type: string + enum: + - plain + - s256 + responses: + 302: + $ref: "#/components/responses/OAuthAuthorizeRedirectResponse" + + /signup: + post: + summary: Signs a user up. + description: > + Creates a new user. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + examples: + "email+password": + value: + email: user@example.com + password: password1 + "phone+password": + value: + phone: "+1234567890" + password: password1 + "phone+password+whatsapp": + value: + phone: "+1234567890" + password: password1 + channel: whatsapp + "email+password+pkce": + value: + email: user@example.com + password: password1 + code_challenge_method: s256 + code_challenge: elU6u5zyqQT2f92GRQUq6PautAeNDf4DQPayyR0ek_c& + schema: + type: object + properties: + email: + type: string + format: email + phone: + type: string + format: phone + channel: + type: string + enum: + - sms + - whatsapp + password: + type: string + data: + type: object + code_challenge: + type: string + code_challenge_method: + type: string + enum: + - plain + - s256 + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueSecurity" + responses: + 200: + description: > + A user already exists and is not confirmed (in which case a user object is returned). A user did not exist and is signed up. If email or phone confirmation is enabled, returns a user object. If confirmation is disabled, returns an access token and refresh token response. + content: + application/json: + schema: + oneOf: + - $ref: "#/components/schemas/AccessTokenResponseSchema" + - $ref: "#/components/schemas/UserSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /recover: + post: + summary: Request password recovery. + description: > + Users that have forgotten their password can have it reset with this API. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - email + properties: + email: + type: string + format: email + code_challenge: + type: string + code_challenge_method: + type: string + enum: + - plain + - s256 + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueSecurity" + responses: + 200: + description: A recovery email has been sent to the address. An empty JSON object is returned. To obfuscate whether such an email address already exists in the system this response is sent regardless whether the address exists or not. + content: + application/json: + schema: + type: object + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: Returned when unable to validate the email address. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /resend: + post: + summary: Resends a one-time password (OTP) through email or SMS. + description: > + Allows a user to resend an existing signup, sms, email_change or phone_change OTP. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + email: + type: string + format: email + description: > + Applicable only if `type` is with regards to an email address. + phone: + type: string + format: phone + description: > + Applicable only if `type` is with regards to an phone number. + type: + type: string + enum: + - signup + - email_change + - sms + - phone_change + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueSecurity" + responses: + 200: + description: A One-Time Password was sent to the email or phone. To obfuscate whether such an address or number already exists in the system this response is sent in both cases. + content: + application/json: + schema: + type: object + properties: + message_id: + type: string + description: Unique ID of the message as reported by the SMS sending provider. Useful for tracking deliverability problems. + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: Returned when unable to validate the email address or phone number. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /magiclink: + post: + summary: Authenticate a user by sending them a magic link. + description: > + A magic link is a special type of URL that includes a One-Time Password. When a user visits this link in a browser they are immediately authenticated. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - email + properties: + email: + type: string + format: email + data: + type: object + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueSecurity" + responses: + 200: + description: A recovery email has been sent to the address. An empty JSON object is returned. To obfuscate whether such an email address already exists in the system this response is sent regardless whether the address exists or not. + content: + application/json: + schema: + type: object + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: Returned when unable to validate the email address. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /otp: + post: + summary: Authenticate a user by sending them a One-Time Password over email or SMS. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + email: + type: string + format: email + phone: + type: string + format: phone + channel: + type: string + enum: + - sms + - whatsapp + create_user: + type: boolean + data: + type: object + code_challenge_method: + type: string + enum: + - s256 + - plain + code_challenge: + type: string + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueSecurity" + responses: + 200: + description: A One-Time Password was sent to the email or phone. To obfuscate whether such an address or number already exists in the system this response is sent in both cases. + content: + application/json: + schema: + type: object + properties: + message_id: + type: string + description: Unique ID of the message as reported by the SMS sending provider. Useful for tracking deliverability problems. + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: Returned when unable to validate the email or phone number. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /user: + get: + summary: Fetch the latest user account information. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + responses: + 200: + description: User's account information. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + put: + summary: Update certain properties of the current user account. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + email: + type: string + format: email + phone: + type: string + format: phone + password: + type: string + nonce: + type: string + data: + type: object + app_metadata: + type: object + channel: + type: string + enum: + - sms + - whatsapp + responses: + 200: + description: User's updated account information. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /user/identities/authorize: + get: + summary: Links an OAuth identity to an existing user. Redirects to an external OAuth provider. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: provider + in: query + description: Name of the OAuth provider. + example: google + required: true + schema: + type: string + pattern: "^[a-zA-Z0-9]+$" + - name: scopes + in: query + required: true + description: Space separated list of OAuth scopes to pass on to `provider`. + schema: + type: string + pattern: "[^ ]+( +[^ ]+)*" + - name: redirect_to + in: query + description: > + (Optional) URL to redirect back into the app on after OAuth sign-in completes successfully or not. If not specified will use the "Site URL" configuration option. If not allowed per the allow list it will use the "Site URL" configuration option. + schema: + type: string + format: uri + - name: code_challenge_method + in: query + description: (Optional) Method used to encrypt the verifier. Can be `plain` (no transformation) or `s256` (where SHA-256 is used). It is always recommended that `s256` is used. + schema: + type: string + enum: + - plain + - s256 + responses: + 302: + $ref: "#/components/responses/OAuthAuthorizeRedirectResponse" + + /user/identities/{identityId}: + parameters: + - name: identityId + in: path + required: true + schema: + type: string + format: uuid + delete: + summary: Unlinks an identity from the current user. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + responses: + 200: + description: User's account data. + content: + application/json: + schema: + type: object + 401: + description: The user is not authenticated. + content: + application/json: + schema: + $ref: "#/components/responses/UnauthorizedResponse" + examples: + example: + summary: no_authorization + value: + error_code: no_authorization + 403: + description: Forbidden error + content: + application/json: + schema: + $ref: "#/components/responses/ForbiddenResponse" + examples: + example: + summary: bad_jwt + value: + error_code: bad_jwt + example1: + summary: unexpected_audience + value: + error_code: unexpected_audience + 404: + description: Not found error + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + examples: + example: + summary: manual_linking_disabled + value: + error_code: manual_linking_disabled + example1: + summary: validation_failed + value: + error_code: validation_failed + 422: + description: Unprocessable entity + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + examples: + example: + summary: single_identity_not_deletable + value: + error_code: single_identity_not_deletable + example1: + summary: identity_already_exists + value: + error_code: identity_already_exists + example2: + summary: identity_not_found + value: + error_code: identity_not_found + example3: + summary: email_conflict_identity_not_deletable + value: + error_code: email_conflict_identity_not_deletable + + /reauthenticate: + post: + summary: Reauthenticates the possession of an email or phone number for the purpose of password change. + description: > + For a password to be changed on a user account, the user's email or phone number needs to be confirmed before they are allowed to set a new password. This requirement is configurable. This API sends a confirmation email or SMS message. A nonce in this message can be provided in `PUT /user` to change the password on the account. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + responses: + 200: + description: A One-Time Password was sent to the user's email or phone. + content: + application/json: + schema: + type: object + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /factors: + post: + summary: Begin enrolling a new factor for MFA. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - factor_type + properties: + factor_type: + type: string + enum: + - totp + - phone + - webauthn + friendly_name: + type: string + issuer: + type: string + format: uri + phone: + type: string + format: phone + responses: + 200: + description: > + A new factor was created in the unverified state. Call `POST /factors/{factorId}/verify' to verify it. + content: + application/json: + schema: + type: object + properties: + id: + type: string + type: + type: string + enum: + - totp + - phone + - webauthn + totp: + type: object + properties: + qr_code: + type: string + secret: + type: string + uri: + type: string + phone: + type: string + format: phone + + 400: + $ref: "#/components/responses/BadRequestResponse" + + /factors/{factorId}/challenge: + post: + summary: Create a new challenge for a MFA factor. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: factorId + in: path + required: true + example: 2b306a77-21dc-4110-ba71-537cb56b9e98 + schema: + type: string + format: uuid + requestBody: + content: + application/json: + schema: + type: object + properties: + channel: + type: string + enum: + - sms + - whatsapp + + responses: + 200: + description: > + A new challenge was generated for the factor. Use `POST /factors/{factorId}/verify` to verify the challenge. + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/TOTPPhoneChallengeResponse' + - $ref: '#/components/schemas/WebAuthnChallengeResponse' + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /factors/{factorId}/verify: + post: + summary: Verify a challenge on a factor. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: factorId + in: path + required: true + example: 2b306a77-21dc-4110-ba71-537cb56b9e98 + schema: + type: string + format: uuid + requestBody: + content: + application/json: + schema: + type: object + required: + - challenge_id + properties: + challenge_id: + type: string + format: uuid + code: + type: string + responses: + 200: + description: > + This challenge has been verified. Client libraries should replace their stored access and refresh tokens with the ones provided in this response. These new credentials have an increased Authenticator Assurance Level (AAL). + content: + application/json: + schema: + $ref: "#/components/schemas/AccessTokenResponseSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /factors/{factorId}: + delete: + summary: Remove a MFA factor from a user. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: factorId + in: path + required: true + example: 2b306a77-21dc-4110-ba71-537cb56b9e98 + schema: + type: string + format: uuid + responses: + 200: + description: > + This MFA factor is removed (unenrolled) and cannot be used for increasing the AAL level of user's sessions. Client libraries should use the `POST /token?grant_type=refresh_token` endpoint to get a new access and refresh token with a decreased AAL. + content: + application/json: + schema: + type: object + properties: + id: + type: string + format: uuid + example: 2b306a77-21dc-4110-ba71-537cb56b9e98 + 400: + $ref: "#/components/responses/BadRequestResponse" + + /callback: + get: + summary: Redirects OAuth flow errors to the frontend app. + description: > + When an OAuth sign-in flow fails for any reason, the error message needs to be delivered to the frontend app requesting the flow. This callback delivers the errors as `error` and `error_description` query params. Usually this request is not called directly. + tags: + - oauth + security: + - APIKeyAuth: [] + responses: + 302: + $ref: "#/components/responses/OAuthCallbackRedirectResponse" + post: + summary: Redirects OAuth flow errors to the frontend app. + description: > + When an OAuth sign-in flow fails for any reason, the error message needs to be delivered to the frontend app requesting the flow. This callback delivers the errors as `error` and `error_description` query params. Usually this request is not called directly. + tags: + - oauth + responses: + 302: + $ref: "#/components/responses/OAuthCallbackRedirectResponse" + + /sso: + post: + summary: Initiate a Single-Sign On flow. + tags: + - sso + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + domain: + type: string + format: hostname + description: Email address domain used to identify the SSO provider. + provider_id: + type: string + format: uuid + example: 40451fc2-4997-429c-bf7f-cc6f33c788e6 + redirect_to: + type: string + format: uri + skip_http_redirect: + type: boolean + description: Set to `true` if the response to this request should not be a HTTP 303 redirect -- useful for browser-based applications. + code_challenge: + type: string + code_challenge_method: + type: string + enum: + - plain + - s256 + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueSecurity" + responses: + 200: + description: > + Returned only when `skip_http_redirect` is `true` and the SSO provider could be identified from the `provider_id` or `domain`. Client libraries should use the returned URL to redirect or open a browser. + content: + application/json: + schema: + type: object + properties: + url: + type: string + format: uri + 303: + description: > + Returned only when `skip_http_redirect` is `false` or not present and the SSO provider could be identified from the `provider_id` or `domain`. Client libraries should follow the redirect. 303 is used instead of 302 because the request should be executed with a `GET` verb. + headers: + Location: + schema: + type: string + format: uri + 400: + $ref: "#/components/responses/BadRequestResponse" + 404: + description: > + Returned when the SSO provider could not be identified. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /saml/metadata: + get: + summary: Returns the SAML 2.0 Metadata XML. + description: > + The metadata XML can be downloaded or used for the SAML 2.0 Metadata URL discovery mechanism. This URL is the SAML 2.0 EntityID of the Service Provider implemented by this server. + tags: + - saml + security: + - APIKeyAuth: [] + parameters: + - name: download + in: query + description: > + If set to `true` will add a `Content-Disposition` header to the response which will trigger a download dialog on the browser. + schema: + type: boolean + responses: + 200: + description: > + A valid SAML 2.0 Metadata XML document. Should be cached according to the `Cache-Control` header and/or caching data specified in the document itself. + headers: + Content-Disposition: + description: > + Present if `download=true`, which triggers the browser to show a donwload dialog. + schema: + type: string + example: attachment; filename="metadata.xml" + Cache-Control: + description: > + Should be parsed and obeyed to avoid putting strain on the server. + schema: + type: string + example: public, max-age=600 + + /saml/acs: + post: + summary: SAML 2.0 Assertion Consumer Service (ACS) endpoint. + description: > + Implements the SAML 2.0 Assertion Consumer Service (ACS) endpoint supporting the POST and Artifact bindings. + tags: + - saml + security: [] + parameters: + - name: RelayState + in: query + schema: + oneOf: + - type: string + format: uri + description: URL to take the user to after the ACS has been verified. Often sent by Identity Provider initiated login requests. + - type: string + format: uuid + description: UUID of the SAML Relay State stored in the database, used to identify the Service Provider initiated login request. + - name: SAMLArt + in: query + description: > + See the SAML 2.0 ACS specification. Cannot be used without a UUID `RelayState` parameter. + schema: + type: string + - name: SAMLResponse + in: query + description: > + See the SAML 2.0 ACS specification. Must be present unless `SAMLArt` is specified. If `RelayState` is not a UUID, the SAML Response is unpacked and the identity provider is identified from the response. + schema: + type: string + responses: + 302: + $ref: "#/components/responses/AccessRefreshTokenRedirectResponse" + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /invite: + post: + summary: Invite a user by email. + description: > + Sends an invitation email which contains a link that allows the user to sign-in. + tags: + - admin + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - email + properties: + email: + type: string + data: + type: object + responses: + 200: + description: An invitation has been sent to the user. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: User already exists and has confirmed their address. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/generate_link: + post: + summary: Generate a link to send in an email message. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - type + - email + properties: + type: + type: string + enum: + - magiclink + - signup + - recovery + - email_change_current + - email_change_new + email: + type: string + format: email + new_email: + type: string + format: email + password: + type: string + data: + type: object + redirect_to: + type: string + format: uri + responses: + 200: + description: User profile and generated link information. + content: + application/json: + schema: + type: object + additionalProperties: true + properties: + action_link: + type: string + format: uri + email_otp: + type: string + hashed_token: + type: string + verification_type: + type: string + redirect_to: + type: string + format: uri + 400: + $ref: "#/components/responses/BadRequestResponse" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 422: + description: > + Has multiple meanings: + - User already exists + - Provided password does not meet minimum criteria + - Secure email change not enabled + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/audit: + get: + summary: Fetch audit log events. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + parameters: + - name: page + in: query + schema: + type: integer + minimum: 1 + default: 1 + - name: per_page + in: query + schema: + type: integer + minimum: 1 + default: 50 + responses: + 200: + description: List of audit logs. + content: + application/json: + schema: + type: array + items: + type: object + properties: + id: + type: string + format: uuid + payload: + type: object + properties: + actor_id: + type: string + actor_via_sso: + type: boolean + description: Whether the actor used a SSO protocol (like SAML 2.0 or OIDC) to authenticate. + actor_username: + type: string + actor_name: + type: string + traits: + type: object + action: + type: string + description: |- + Usually one of these values: + - login + - logout + - invite_accepted + - user_signedup + - user_invited + - user_deleted + - user_modified + - user_recovery_requested + - user_reauthenticate_requested + - user_confirmation_requested + - user_repeated_signup + - user_updated_password + - token_revoked + - token_refreshed + - generate_recovery_codes + - factor_in_progress + - factor_unenrolled + - challenge_created + - verification_attempted + - factor_deleted + - recovery_codes_deleted + - factor_updated + - mfa_code_login + log_type: + type: string + description: |- + Usually one of these values: + - account + - team + - token + - user + - factor + - recovery_codes + created_at: + type: string + format: date-time + ip_address: + type: string + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + + /admin/users: + get: + summary: Fetch a listing of users. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + parameters: + - name: page + in: query + schema: + type: integer + minimum: 1 + default: 1 + - name: per_page + in: query + schema: + type: integer + minimum: 1 + default: 50 + responses: + 200: + description: A page of users. + content: + application/json: + schema: + type: object + properties: + aud: + type: string + deprecated: true + users: + type: array + items: + $ref: "#/components/schemas/UserSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + + /admin/users/{userId}: + parameters: + - name: userId + in: path + required: true + schema: + type: string + format: uuid + get: + summary: Fetch user account data for a user. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: User's account data. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + put: + summary: Update user's account data. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + responses: + 200: + description: User's account data was updated. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + delete: + summary: Delete a user. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: User's account data. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/users/{userId}/factors: + parameters: + - name: userId + in: path + required: true + schema: + type: string + format: uuid + get: + summary: List all of the MFA factors for a user. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: User's MFA factors. + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/MFAFactorSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/users/{userId}/factors/{factorId}: + parameters: + - name: userId + in: path + required: true + schema: + type: string + format: uuid + - name: factorId + in: path + required: true + schema: + type: string + format: uuid + put: + summary: Update a user's MFA factor. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + responses: + 200: + description: User's MFA factor. + content: + application/json: + schema: + $ref: "#/components/schemas/MFAFactorSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user and/or factor. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + delete: + summary: Remove a user's MFA factor. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: User's MFA factor. + content: + application/json: + schema: + $ref: "#/components/schemas/MFAFactorSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user and/or factor. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/sso/providers: + get: + summary: Fetch a list of all registered SSO providers. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: A list of all providers. + content: + application/json: + schema: + type: object + properties: + items: + type: array + items: + $ref: "#/components/schemas/SSOProviderSchema" + post: + summary: Register a new SSO provider. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - type + properties: + type: + type: string + enum: + - saml + metadata_url: + type: string + format: uri + metadata_xml: + type: string + domains: + type: array + items: + type: string + format: hostname + attribute_mapping: + $ref: "#/components/schemas/SAMLAttributeMappingSchema" + responses: + 200: + description: SSO provider was created. + content: + application/json: + schema: + $ref: "#/components/schemas/SSOProviderSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + + /admin/sso/providers/{ssoProviderId}: + parameters: + - name: ssoProviderId + in: path + required: true + schema: + type: string + format: uuid + get: + summary: Fetch SSO provider details. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: SSO provider exists with these details. + content: + application/json: + schema: + $ref: "#/components/schemas/SSOProviderSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: A provider with this UUID does not exist. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + put: + summary: Update details about a SSO provider. + description: > + You can only update only one of `metadata_url` or `metadata_xml` at once. The SAML Metadata represented by these updates must advertize the same Identity Provider EntityID. Do not include the `domains` or `attribute_mapping` property to keep the existing database values. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + metadata_url: + type: string + format: uri + metadata_xml: + type: string + domains: + type: array + items: + type: string + pattern: "[a-z0-9-]+([.][a-z0-9-]+)*" + attribute_mapping: + $ref: "#/components/schemas/SAMLAttributeMappingSchema" + responses: + 200: + description: SSO provider details were updated. + content: + application/json: + schema: + $ref: "#/components/schemas/SSOProviderSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: A provider with this UUID does not exist. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + delete: + summary: Remove an SSO provider. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: SSO provider was removed. + content: + application/json: + schema: + $ref: "#/components/schemas/SSOProviderSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: A provider with this UUID does not exist. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /health: + get: + summary: Service healthcheck. + description: Ping this endpoint to receive information about the health of the service. + tags: + - general + security: + - APIKeyAuth: [] + responses: + 200: + description: > + Service is healthy. + content: + application/json: + schema: + type: object + properties: + version: + type: string + example: v2.40.1 + name: + type: string + example: GoTrue + description: + type: string + example: GoTrue is a user registration and authentication API + + 500: + description: > + Service is not healthy. Retriable with exponential backoff. + 502: + description: > + Service is not healthy: infrastructure issue. Usually not retriable. + 503: + description: > + Service is not healthy: infrastrucutre issue. Retriable with exponential backoff. + 504: + description: > + Service is not healthy: request timed out. Retriable with exponential backoff. + + /settings: + get: + summary: Retrieve some of the public settings of the server. + description: > + Use this endpoint to configure parts of any authentication UIs depending on the configured settings. + tags: + - general + security: + - APIKeyAuth: [] + responses: + 200: + description: > + Currently applicable settings of the server. + content: + application/json: + schema: + type: object + properties: + disable_signup: + type: boolean + example: false + description: Whether new accounts can be created. (Valid for all providers.) + mailer_autoconfirm: + type: boolean + example: false + description: Whether new email addresses need to be confirmed before sign-in is possible. + phone_autoconfirm: + type: boolean + example: false + description: Whether new phone numbers need to be confirmed before sign-in is possible. + sms_provider: + type: string + optional: true + example: twilio + description: Which SMS provider is being used to send messages to phone numbers. + saml_enabled: + type: boolean + example: true + description: Whether SAML is enabled on this API server. Defaults to false. + external: + type: object + description: Which external identity providers are enabled. + example: + github: true + apple: true + email: true + phone: true + patternProperties: + "[a-zA-Z0-9]+": + type: boolean + +components: + securitySchemes: + UserAuth: + type: http + scheme: bearer + description: > + An access token in the form of a JWT issued by this server. + + AdminAuth: + type: http + scheme: bearer + description: > + A special admin JWT. + + APIKeyAuth: + type: apiKey + in: header + name: apikey + description: > + When deployed on Supabase, this server requires an `apikey` header containing a valid Supabase-issued API key to call any endpoint. + + schemas: + GoTrueSecurity: + type: object + description: > + Use this property to pass a CAPTCHA token only if you have enabled CAPTCHA protection. + properties: + captcha_token: + type: string + + ErrorSchema: + type: object + properties: + error: + type: string + description: |- + Certain responses will contain this property with the provided values. + + Usually one of these: + - invalid_request + - unauthorized_client + - access_denied + - server_error + - temporarily_unavailable + - unsupported_otp_type + error_description: + type: string + description: > + Certain responses that have an `error` property may have this property which describes the error. + code: + type: integer + description: > + The HTTP status code. Usually missing if `error` is present. + example: 400 + msg: + type: string + description: > + A basic message describing the problem with the request. Usually missing if `error` is present. + error_code: + type: string + description: > + A short code used to describe the class of error encountered. + weak_password: + type: object + description: > + Only returned on the `/signup` endpoint if the password used is too weak. Inspect the `reasons` and `msg` property to identify the causes. + properties: + reasons: + type: array + items: + type: string + enum: + - length + - characters + - pwned + + UserSchema: + type: object + description: Object describing the user related to the issued access and refresh tokens. + properties: + id: + type: string + format: uuid + aud: + type: string + deprecated: true + role: + type: string + email: + type: string + description: User's primary contact email. In most cases you can uniquely identify a user by their email address, but not in all cases. + email_confirmed_at: + type: string + format: date-time + phone: + type: string + format: phone + description: User's primary contact phone number. In most cases you can uniquely identify a user by their phone number, but not in all cases. + phone_confirmed_at: + type: string + format: date-time + confirmation_sent_at: + type: string + format: date-time + confirmed_at: + type: string + format: date-time + recovery_sent_at: + type: string + format: date-time + new_email: + type: string + format: email + email_change_sent_at: + type: string + format: date-time + new_phone: + type: string + format: phone + phone_change_sent_at: + type: string + format: date-time + reauthentication_sent_at: + type: string + format: date-time + last_sign_in_at: + type: string + format: date-time + app_metadata: + type: object + user_metadata: + type: object + factors: + type: array + items: + $ref: "#/components/schemas/MFAFactorSchema" + identities: + type: array + items: + $ref: "#/components/schemas/IdentitySchema" + banned_until: + type: string + format: date-time + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + deleted_at: + type: string + format: date-time + is_anonymous: + type: boolean + + SAMLAttributeMappingSchema: + type: object + properties: + keys: + type: object + patternProperties: + ".+": + type: object + properties: + name: + type: string + names: + type: array + items: + type: string + default: + oneOf: + - type: string + - type: number + - type: boolean + - type: object + + SSOProviderSchema: + type: object + properties: + id: + type: string + format: uuid + sso_domains: + type: array + items: + type: object + properties: + domain: + type: string + format: hostname + saml: + type: object + properties: + entity_id: + type: string + metadata_xml: + type: string + metadata_url: + type: string + attribute_mapping: + $ref: "#/components/schemas/SAMLAttributeMappingSchema" + + AccessTokenResponseSchema: + type: object + properties: + access_token: + type: string + description: A valid JWT that will expire in `expires_in` seconds. + refresh_token: + type: string + description: An opaque string that can be used once to obtain a new access and refresh token. + token_type: + type: string + description: What type of token this is. Only `bearer` returned, may change in the future. + expires_in: + type: integer + description: Number of seconds after which the `access_token` should be renewed by using the refresh token with the `refresh_token` grant type. + expires_at: + type: integer + description: UNIX timestamp after which the `access_token` should be renewed by using the refresh token with the `refresh_token` grant type. + weak_password: + type: object + description: Only returned on the `/token?grant_type=password` endpoint. When present, it indicates that the password used is weak. Inspect the `reasons` and/or `message` properties to identify why. + properties: + reasons: + type: array + items: + type: string + enum: + - length + - characters + - pwned + message: + type: string + user: + $ref: "#/components/schemas/UserSchema" + + MFAFactorSchema: + type: object + description: Represents a MFA factor. + properties: + id: + type: string + format: uuid + status: + type: string + description: |- + Usually one of: + - verified + - unverified + friendly_name: + type: string + factor_type: + type: string + description: |- + Usually one of: + - totp + - phone + - webauthn + web_authn_credential: + type: string + phone: + type: string + format: phone + nullable: true + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + last_challenged_at: + type: string + format: date-time + nullable: true + + + IdentitySchema: + type: object + properties: + identity_id: + type: string + format: uuid + id: + type: string + format: uuid + user_id: + type: string + format: uuid + identity_data: + type: object + provider: + type: string + last_sign_in_at: + type: string + format: date-time + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + email: + type: string + format: email + TOTPPhoneChallengeResponse: + type: object + required: + - id + - type + - expires_at + properties: + id: + type: string + format: uuid + example: 14c1560e-2749-4522-bb62-d1458451830a + description: ID of the challenge. + type: + type: string + enum: [totp, phone] + description: Type of the challenge. + expires_at: + type: integer + example: 1674840917 + description: UNIX seconds of the timestamp past which the challenge should not be verified. + + WebAuthnChallengeResponse: + type: object + required: + - id + - type + - expires_at + - credential_options + properties: + id: + type: string + format: uuid + example: 14c1560e-2749-4522-bb62-d1458451830a + description: ID of the challenge. + type: + type: string + enum: [webauthn] + description: Type of the challenge. + expires_at: + type: integer + example: 1674840917 + description: UNIX seconds of the timestamp past which the challenge should not be verified. + credential_request_options: + $ref: '#/components/schemas/CredentialRequestOptions' + credential_creation_options: + $ref: '#/components/schemas/CredentialCreationOptions' + + CredentialAssertion: + type: object + description: WebAuthn credential assertion options + required: + - challenge + - rpId + - allowCredentials + - timeout + properties: + challenge: + type: string + description: A random challenge generated by the server, base64url encoded + example: "Y2hhbGxlbmdlAyv-5P0kw1SG-OxhLbSHpRLdWaVR1w" + rpId: + type: string + description: The relying party's identifier (usually the domain name) + example: "example.com" + allowCredentials: + type: array + description: List of credentials acceptable for this authentication + items: + type: object + required: + - id + - type + properties: + id: + type: string + description: Credential ID, base64url encoded + example: "AXwyVxYT7BgNKwNq0YqUXaHHIdRK6OdFGCYgZF9K6zNu" + type: + type: string + enum: [public-key] + description: Type of the credential + timeout: + type: integer + description: Time (in milliseconds) that the user has to respond to the authentication prompt + example: 60000 + userVerification: + type: string + enum: [required, preferred, discouraged] + description: The relying party's requirements for user verification + default: preferred + extensions: + type: object + description: Additional parameters requesting additional processing by the client + status: + type: string + enum: [ok, failed] + description: Status of the credential assertion + errorMessage: + type: string + description: Error message if the assertion failed + userHandle: + type: string + description: User handle, base64url encoded + authenticatorAttachment: + type: string + enum: [platform, cross-platform] + description: Type of authenticator to use + + CredentialRequest: + type: object + description: WebAuthn credential request (for the response from the client) + required: + - id + - rawId + - type + - response + properties: + id: + type: string + description: Base64url encoding of the credential ID + example: "AXwyVxYT7BgNKwNq0YqUXaHHIdRK6OdFGCYgZF9K6zNu" + rawId: + type: string + description: Base64url encoding of the credential ID (same as id) + example: "AXwyVxYT7BgNKwNq0YqUXaHHIdRK6OdFGCYgZF9K6zNu" + type: + type: string + enum: [public-key] + description: Type of the credential + response: + type: object + required: + - clientDataJSON + - authenticatorData + - signature + - userHandle + properties: + clientDataJSON: + type: string + description: Base64url encoding of the client data + example: "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0IiwiY2hhbGxlbmdlIjoiY2hhbGxlbmdlIiwib3JpZ2luIjoiaHR0cHM6Ly9leGFtcGxlLmNvbSJ9" + authenticatorData: + type: string + description: Base64url encoding of the authenticator data + example: "SZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2MBAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAXwyVxYT7BgNKwNq0YqUXaHHIdRK6OdFGCYgZF9K6zNu" + signature: + type: string + description: Base64url encoding of the signature + example: "MEUCIQCx5cJVAB3kGP6bqCIoAV6CkBpVAf8rcx0WSZ22fIxXvQIgCKFt9pEu1vK8U4JKYTfn6tGjvGNfx2F4uXrHSXlefvM" + userHandle: + type: string + description: Base64url encoding of the user handle + example: "MQ" + clientExtensionResults: + type: object + description: Client extension results + + CredentialRequestOptions: + type: object + description: Options for requesting an assertion + properties: + challenge: + type: string + format: byte + description: A challenge to be signed by the authenticator + timeout: + type: integer + description: Time (in milliseconds) that the caller is willing to wait for the call to complete + rpId: + type: string + description: Relying Party ID + allowCredentials: + type: array + items: + $ref: '#/components/schemas/PublicKeyCredentialDescriptor' + userVerification: + type: string + enum: [required, preferred, discouraged] + description: User verification requirement + + CredentialCreationOptions: + type: object + description: Options for creating a new credential + properties: + rp: + type: object + properties: + id: + type: string + name: + type: string + user: + $ref: '#/components/schemas/UserSchema' + + challenge: + type: string + format: byte + description: A challenge to be signed by the authenticator + pubKeyCredParams: + type: array + items: + type: object + properties: + type: + type: string + enum: [public-key] + alg: + type: integer + timeout: + type: integer + description: Time (in milliseconds) that the caller is willing to wait for the call to complete + excludeCredentials: + type: array + items: + $ref: '#/components/schemas/PublicKeyCredentialDescriptor' + authenticatorSelection: + type: object + properties: + authenticatorAttachment: + type: string + enum: [platform, cross-platform] + requireResidentKey: + type: boolean + userVerification: + type: string + enum: [required, preferred, discouraged] + attestation: + type: string + enum: [none, indirect, direct] + description: Preferred attestation conveyance + + PublicKeyCredentialDescriptor: + type: object + properties: + type: + type: string + enum: [public-key] + id: + type: string + format: byte + description: Credential ID + transports: + type: array + items: + type: string + enum: [usb, nfc, ble, internal] + + responses: + OAuthCallbackRedirectResponse: + description: > + HTTP Redirect to a URL containing the `error` and `error_description` query parameters which should be shown to the user requesting the OAuth sign-in flow. + headers: + Location: + description: > + URL containing the `error` and `error_description` query parameters. + schema: + type: string + format: uri + example: https://example.com/?error=server_error&error_description=User%20does%20not%20exist. + + OAuthAuthorizeRedirectResponse: + description: > + HTTP Redirect to the OAuth identity provider's authorization URL. + headers: + Location: + description: > + URL to which the user agent should redirect (or open in a browser for mobile apps). + schema: + type: string + format: uri + + RateLimitResponse: + description: > + HTTP Too Many Requests response, when a rate limiter has been breached. + content: + application/json: + schema: + type: object + properties: + code: + type: integer + example: 429 + msg: + type: string + description: A basic message describing the rate limit breach. Do not use as an error code identifier. + example: Too many requests. Please try again in a few seconds. + + BadRequestResponse: + description: > + HTTP Bad Request response. Can occur if the passed in JSON cannot be unmarshalled properly or when CAPTCHA verification was not successful. In certain cases can also occur when features are disabled on the server (e.g. sign ups). It may also mean that the operation failed due to some constraint not being met (such a user already exists for example). + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + UnauthorizedResponse: + description: > + HTTP Unauthorizred response. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + ForbiddenResponse: + description: > + HTTP Forbidden response. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + InternalServerErrorResponse: + description: > + HTTP Internal Server Error. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + AccessRefreshTokenRedirectResponse: + description: > + HTTP See Other redirect response where `Location` is a specially formatted URL that includes an `access_token`, `refresh_token`, `expires_in` as URL query encoded values in the URL fragment (anything after `#`). These values are encoded in the fragment as this value is only visible to the browser handling the redirect and is not sent to the server. + headers: + Location: + schema: + type: string + format: uri + example: https://example.com/#access_token=...&refresh_token=...&expires_in=... diff --git a/security/hcaptcha.go b/security/hcaptcha.go deleted file mode 100644 index 4017f8571..000000000 --- a/security/hcaptcha.go +++ /dev/null @@ -1,95 +0,0 @@ -package security - -import ( - "bytes" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" -) - -type GotrueRequest struct { - Security GotrueSecurity `json:"gotrue_meta_security"` -} - -type GotrueSecurity struct { - Token string `json:"hcaptcha_token"` -} - -type VerificationResponse struct { - Success bool `json:"success"` - ErrorCodes []string `json:"error-codes"` - Hostname string `json:"hostname"` -} - -type VerificationResult int - -const ( - UserRequestFailed VerificationResult = iota - VerificationProcessFailure - SuccessfullyVerified -) - -var Client *http.Client - -func init() { - // TODO (darora): make timeout configurable - Client = &http.Client{Timeout: 10 * time.Second} -} - -func VerifyRequest(r *http.Request, secretKey string) (VerificationResult, error) { - res := GotrueRequest{} - bodyBytes, err := ioutil.ReadAll(r.Body) - if err != nil { - return UserRequestFailed, err - } - r.Body.Close() - // re-init body so downstream route handlers don't get borked - r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) - - jsonDecoder := json.NewDecoder(bytes.NewBuffer(bodyBytes)) - err = jsonDecoder.Decode(&res) - if err != nil || strings.TrimSpace(res.Security.Token) == "" { - return UserRequestFailed, errors.Wrap(err, "couldn't decode captcha info") - } - clientIP := strings.Split(r.RemoteAddr, ":")[0] - return verifyCaptchaCode(res.Security.Token, secretKey, clientIP) -} - -func verifyCaptchaCode(token string, secretKey string, clientIP string) (VerificationResult, error) { - data := url.Values{} - data.Set("secret", secretKey) - data.Set("response", token) - data.Set("remoteip", clientIP) - // TODO (darora): pipe through sitekey - - r, err := http.NewRequest("POST", "https://hcaptcha.com/siteverify", strings.NewReader(data.Encode())) - if err != nil { - return VerificationProcessFailure, errors.Wrap(err, "couldn't initialize request object for hcaptcha check") - } - r.Header.Add("Content-Type", "application/x-www-form-urlencoded") - r.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) - res, err := Client.Do(r) - if err != nil { - return VerificationProcessFailure, errors.Wrap(err, "failed to verify hcaptcha token") - } - verResult := VerificationResponse{} - defer res.Body.Close() - decoder := json.NewDecoder(res.Body) - err = decoder.Decode(&verResult) - if err != nil { - return VerificationProcessFailure, errors.Wrap(err, "failed to decode hcaptcha response") - } - logrus.WithField("result", verResult).Info("obtained hcaptcha verification result") - if !verResult.Success { - return UserRequestFailed, fmt.Errorf("user request suppressed by hcaptcha") - } - return SuccessfullyVerified, nil -} diff --git a/storage/dial.go b/storage/dial.go deleted file mode 100644 index e176851ea..000000000 --- a/storage/dial.go +++ /dev/null @@ -1,78 +0,0 @@ -package storage - -import ( - "net/url" - "reflect" - - _ "github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/dialers/mysql" - _ "github.com/go-sql-driver/mysql" - "github.com/gobuffalo/pop/v5" - "github.com/gobuffalo/pop/v5/columns" - "github.com/netlify/gotrue/conf" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" -) - -// Connection is the interface a storage provider must implement. -type Connection struct { - *pop.Connection -} - -// Dial will connect to that storage engine -func Dial(config *conf.GlobalConfiguration) (*Connection, error) { - if config.DB.Driver == "" && config.DB.URL != "" { - u, err := url.Parse(config.DB.URL) - if err != nil { - return nil, errors.Wrap(err, "parsing db connection url") - } - config.DB.Driver = u.Scheme - } - - db, err := pop.NewConnection(&pop.ConnectionDetails{ - Dialect: config.DB.Driver, - URL: config.DB.URL, - }) - if err != nil { - return nil, errors.Wrap(err, "opening database connection") - } - if err := db.Open(); err != nil { - return nil, errors.Wrap(err, "checking database connection") - } - if logrus.StandardLogger().Level == logrus.DebugLevel { - pop.Debug = true - } - - return &Connection{db}, nil -} - -func (c *Connection) Transaction(fn func(*Connection) error) error { - if c.TX == nil { - return c.Connection.Transaction(func(tx *pop.Connection) error { - return fn(&Connection{tx}) - }) - } - return fn(c) -} - -func getExcludedColumns(model interface{}, includeColumns ...string) ([]string, error) { - sm := &pop.Model{Value: model} - st := reflect.TypeOf(model) - if st.Kind() == reflect.Ptr { - st = st.Elem() - } - - // get all columns and remove included to get excluded set - cols := columns.ForStructWithAlias(model, sm.TableName(), sm.As, sm.IDField()) - for _, f := range includeColumns { - if _, ok := cols.Cols[f]; !ok { - return nil, errors.Errorf("Invalid column name %s", f) - } - cols.Remove(f) - } - - xcols := make([]string, len(cols.Cols)) - for n := range cols.Cols { - xcols = append(xcols, n) - } - return xcols, nil -} diff --git a/storage/dial_test.go b/storage/dial_test.go deleted file mode 100644 index 8dbd1f493..000000000 --- a/storage/dial_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package storage - -import ( - "testing" - - "github.com/gofrs/uuid" - "github.com/stretchr/testify/require" -) - -type TestUser struct { - ID uuid.UUID - Role string `db:"role"` - Other string `db:"othercol"` -} - -func TestGetExcludedColumns(t *testing.T) { - u := TestUser{} - cols, err := getExcludedColumns(u, "role") - require.NoError(t, err) - require.NotContains(t, cols, "role") - require.Contains(t, cols, "othercol") -} - -func TestGetExcludedColumns_InvalidName(t *testing.T) { - u := TestUser{} - _, err := getExcludedColumns(u, "adsf") - require.Error(t, err) -} diff --git a/storage/session.go b/storage/session.go deleted file mode 100644 index e6189690d..000000000 --- a/storage/session.go +++ /dev/null @@ -1,42 +0,0 @@ -package storage - -import ( - "errors" - "net/http" - - "github.com/gorilla/securecookie" - "github.com/gorilla/sessions" - "github.com/kelseyhightower/envconfig" -) - -var sessionName = "_gotrue_session" -var Store sessions.Store - -type SessionConfig struct { - Key []byte `envconfig:"GOTRUE_SESSION_KEY"` -} - -func init() { - var sessionConfig SessionConfig - err := envconfig.Process("GOTRUE_SESSION_KEY", &sessionConfig) - if err != nil || len(sessionConfig.Key) == 0 { - sessionConfig.Key = securecookie.GenerateRandomKey(32) - } - Store = sessions.NewCookieStore(sessionConfig.Key) -} - -func StoreInSession(key string, value string, req *http.Request, res http.ResponseWriter) error { - session, _ := Store.New(req, sessionName) - session.Values[key] = value - return session.Save(req, res) -} - -func GetFromSession(key string, req *http.Request) (string, error) { - session, _ := Store.Get(req, sessionName) - value, ok := session.Values[key] - if !ok { - return "", errors.New("session could not be found for this request") - } - - return value.(string), nil -} diff --git a/tools/tools.go b/tools/tools.go new file mode 100644 index 000000000..35e0f39ee --- /dev/null +++ b/tools/tools.go @@ -0,0 +1,8 @@ +//go:build tools +// +build tools + +package main + +import ( + _ "github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen" +)