DEVELOPMENT ENVIRONMENT

~liljamo/emerwen-web

ref: cb21a041ada3070c6dca5418e703cfd6af0ae3b8 emerwen-web/internal/handlers/oauth2.go -rw-r--r-- 1.7 KiB
cb21a041Jonni Liljamo feat: initial 4 days ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/*
 * Copyright (C) 2024 Jonni Liljamo <jonni@liljamo.com>
 *
 * This file is licensed under GPL-3.0-or-later, see NOTICE and LICENSE for
 * more information.
 */

package handlers

import (
	"net/http"

	"git.src.quest/~liljamo/emerwen-web/internal/auth"
	"git.src.quest/~liljamo/emerwen-web/internal/components"
	"github.com/alexedwards/scs/v2"
	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gin-gonic/gin"
)

// OAuth2OIDCCallback returns a gin handler for the OAuth2 OIDC callback route.
// OAuth2OIDCCallback returns a gin handler for the OAuth2 OIDC callback route.
//
// The returned handler:
//   - consumes the "state" value stored in the session and rejects the request
//     unless it is non-empty and equal to the "state" query parameter;
//   - rejects provider error responses carried in the "error" query parameter;
//   - exchanges the "code" query parameter for an OAuth2 token via a.Config;
//   - extracts and verifies the "id_token" using a verifier built from
//     a.Provider with a.Config.ClientID as the expected client ID;
//   - decodes the token claims into auth.Claims, renews the session token and
//     stores the claims in the session under "claims";
//   - redirects to "/".
//
// Request and provider failures are rendered with components.Error and status
// 400. Claim decoding and session failures use status 500.
func OAuth2OIDCCallback(a *auth.Auth, sm *scs.SessionManager) gin.HandlerFunc {
	verifier := a.Provider.Verifier(&oidc.Config{ClientID: a.Config.ClientID})

	return func(c *gin.Context) {
		ctx := c.Request.Context()

		// Pop the expected state so it cannot be replayed. Also reject an empty
		// value: a session that never started a login has no stored state, and
		// would otherwise "match" a callback that carries no state either.
		expectedState := sm.PopString(ctx, "state")
		if expectedState == "" || c.Query("state") != expectedState {
			c.HTML(http.StatusBadRequest, "", components.Error(http.StatusBadRequest, "state mismatch"))
			return
		}

		// Authorization failures are reported by the provider through the
		// "error" query parameter instead of a code (RFC 6749 §4.1.2.1).
		if oauthErr := c.Query("error"); oauthErr != "" {
			c.HTML(http.StatusBadRequest, "", components.Error(http.StatusBadRequest, oauthErr))
			return
		}

		oauth2Token, err := a.Config.Exchange(ctx, c.Query("code"))
		if err != nil {
			c.HTML(http.StatusBadRequest, "", components.Error(http.StatusBadRequest, err.Error()))
			return
		}

		rawIDToken, ok := oauth2Token.Extra("id_token").(string)
		if !ok {
			c.HTML(http.StatusBadRequest, "", components.Error(http.StatusBadRequest, "no id_token"))
			return
		}

		idToken, err := verifier.Verify(ctx, rawIDToken)
		if err != nil {
			c.HTML(http.StatusBadRequest, "", components.Error(http.StatusBadRequest, err.Error()))
			return
		}

		var claims auth.Claims
		if err := idToken.Claims(&claims); err != nil {
			c.HTML(http.StatusInternalServerError, "", components.Error(http.StatusInternalServerError, err.Error()))
			return
		}

		// The user's privilege level changes here, so issue a fresh session
		// token to prevent session fixation.
		if err := sm.RenewToken(ctx); err != nil {
			c.HTML(http.StatusInternalServerError, "", components.Error(http.StatusInternalServerError, err.Error()))
			return
		}

		sm.Put(ctx, "claims", claims)

		c.Redirect(http.StatusTemporaryRedirect, "/")
	}
}