ente/server/pkg/middleware/auth.go

89 lines
2.7 KiB
Go
Raw Normal View History

2024-03-01 08:07:01 +00:00
package middleware
import (
"fmt"
"net/http"
"strconv"
"github.com/ente-io/museum/ente/jwt"
"github.com/ente-io/museum/pkg/utils/network"
"github.com/ente-io/museum/pkg/controller/user"
"github.com/ente-io/museum/pkg/repo"
"github.com/ente-io/museum/pkg/utils/auth"
"github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/spf13/viper"
)
// AuthMiddleware intercepts and authenticates incoming requests
type AuthMiddleware struct {
UserAuthRepo *repo.UserAuthRepository
Cache *cache.Cache
UserController *user.UserController
}
// TokenAuthMiddleware returns a middle ware that extracts the `X-AuthToken`
// within the header of a request and uses it to authenticate and insert the
// authenticated user to the request's `X-Auth-User-ID` field.
// If isJWT is true we use JWT token validation
func (m *AuthMiddleware) TokenAuthMiddleware(jwtClaimScope *jwt.ClaimScope) gin.HandlerFunc {
return func(c *gin.Context) {
token := auth.GetToken(c)
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
return
}
app := auth.GetApp(c)
cacheKey := fmt.Sprintf("%s:%s", app, token)
isJWT := false
if jwtClaimScope != nil {
isJWT = true
cacheKey = fmt.Sprintf("%s:%s:%s", app, token, *jwtClaimScope)
}
userID, found := m.Cache.Get(cacheKey)
var err error
if !found {
if isJWT {
userID, err = m.UserController.ValidateJWTToken(token, *jwtClaimScope)
} else {
userID, err = m.UserAuthRepo.GetUserIDWithToken(token, app)
}
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if !isJWT {
ip := network.GetClientIP(c)
userAgent := c.Request.UserAgent()
// skip updating last used for requests routed via CF worker
if !network.IsCFWorkerIP(ip) {
go func() {
_ = m.UserAuthRepo.UpdateLastUsedAt(userID.(int64), token, ip, userAgent)
}()
}
}
m.Cache.Set(cacheKey, userID, cache.DefaultExpiration)
}
c.Request.Header.Set("X-Auth-User-ID", strconv.FormatInt(userID.(int64), 10))
c.Next()
}
}
// AdminAuthMiddleware returns a middle ware that extracts the `userID` added by the TokenAuthMiddleware
// within the header of a request and uses it to check admin status
// NOTE: Should be added after TokenAuthMiddleware middleware
func (m *AuthMiddleware) AdminAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
userID := auth.GetUserID(c.Request.Header)
admins := viper.GetIntSlice("internal.admins")
for _, admin := range admins {
if int64(admin) == userID {
c.Next()
return
}
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "insufficient permissions"})
}
}