From 456274c1b9d90f8e688d9fb4f4207963dbecf3e4 Mon Sep 17 00:00:00 2001 From: mutantmonkey <67266+mutantmonkey@users.noreply.github.com> Date: Fri, 14 Aug 2020 07:42:45 +0000 Subject: [PATCH] Split and move auth into a separate package (#224) * Split and move auth into a separate package This change will make it easier to implement additional authentication methods, such as OpenID Connect. For now, only the existing "apikeys" authentication method is supported. * Use absolute site prefix to prevent redirect loop --- auth.go => auth/apikeys/apikeys.go | 77 +++++++++++++------- auth_test.go => auth/apikeys/apikeys_test.go | 8 +- server.go | 27 ++----- upload.go | 12 ++- 4 files changed, 68 insertions(+), 56 deletions(-) rename auth.go => auth/apikeys/apikeys.go (53%) rename auth_test.go => auth/apikeys/apikeys_test.go (64%) diff --git a/auth.go b/auth/apikeys/apikeys.go similarity index 53% rename from auth.go rename to auth/apikeys/apikeys.go index 3dc5ba6..d2a592d 100644 --- a/auth.go +++ b/auth/apikeys/apikeys.go @@ -1,4 +1,4 @@ -package main +package apikeys import ( "bufio" @@ -24,16 +24,18 @@ const ( type AuthOptions struct { AuthFile string UnauthMethods []string + BasicAuth bool + SiteName string + SitePath string } -type auth struct { +type ApiKeysMiddleware struct { successHandler http.Handler - failureHandler http.Handler authKeys []string o AuthOptions } -func readAuthKeys(authFile string) []string { +func ReadAuthKeys(authFile string) []string { var authKeys []string f, err := os.Open(authFile) @@ -55,7 +57,7 @@ func readAuthKeys(authFile string) []string { return authKeys } -func checkAuth(authKeys []string, key string) (result bool, err error) { +func CheckAuth(authKeys []string, key string) (result bool, err error) { checkKey, err := scrypt.Key([]byte(key), []byte(scryptSalt), scryptN, scryptr, scryptp, scryptKeyLen) if err != nil { return @@ -73,53 +75,74 @@ func checkAuth(authKeys []string, key string) (result bool, err error) { return } -func (a auth) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if sliceContains(a.o.UnauthMethods, r.Method) { +func (a ApiKeysMiddleware) getSitePrefix() string { + prefix := a.o.SitePath + if len(prefix) <= 0 || prefix[0] != '/' { + prefix = "/" + prefix + } + return prefix +} + +func (a ApiKeysMiddleware) goodAuthorizationHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", a.getSitePrefix()) + w.WriteHeader(http.StatusFound) +} + +func (a ApiKeysMiddleware) badAuthorizationHandler(w http.ResponseWriter, r *http.Request) { + if a.o.BasicAuth { + rs := "" + if a.o.SiteName != "" { + rs = fmt.Sprintf(` realm="%s"`, a.o.SiteName) + } + w.Header().Set("WWW-Authenticate", `Basic`+rs) + } + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) +} + +func (a ApiKeysMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var successHandler http.Handler + prefix := a.getSitePrefix() + + if r.URL.Path == prefix+"auth" { + successHandler = http.HandlerFunc(a.goodAuthorizationHandler) + } else { + successHandler = a.successHandler + } + + if sliceContains(a.o.UnauthMethods, r.Method) && r.URL.Path != prefix+"auth" { // allow unauthenticated methods - a.successHandler.ServeHTTP(w, r) + successHandler.ServeHTTP(w, r) return } key := r.Header.Get("Linx-Api-Key") - if key == "" && Config.basicAuth { + if key == "" && a.o.BasicAuth { _, password, ok := r.BasicAuth() if ok { key = password } } - result, err := checkAuth(a.authKeys, key) + result, err := CheckAuth(a.authKeys, key) if err != nil || !result { - a.failureHandler.ServeHTTP(w, r) + http.HandlerFunc(a.badAuthorizationHandler).ServeHTTP(w, r) return } - a.successHandler.ServeHTTP(w, r) + successHandler.ServeHTTP(w, r) } -func UploadAuth(o AuthOptions) func(*web.C, http.Handler) http.Handler { +func NewApiKeysMiddleware(o AuthOptions) func(*web.C, http.Handler) http.Handler { fn := func(c *web.C, h http.Handler) http.Handler { - return auth{ + return ApiKeysMiddleware{ successHandler: h, - failureHandler: http.HandlerFunc(badAuthorizationHandler), - authKeys: readAuthKeys(o.AuthFile), + authKeys: ReadAuthKeys(o.AuthFile), o: o, } } return fn } -func badAuthorizationHandler(w http.ResponseWriter, r *http.Request) { - if Config.basicAuth { - rs := "" - if Config.siteName != "" { - rs = fmt.Sprintf(` realm="%s"`, Config.siteName) - } - w.Header().Set("WWW-Authenticate", `Basic`+rs) - } - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) -} - func sliceContains(slice []string, s string) bool { for _, v := range slice { if s == v { diff --git a/auth_test.go b/auth/apikeys/apikeys_test.go similarity index 64% rename from auth_test.go rename to auth/apikeys/apikeys_test.go index ded98b0..3c2b8e6 100644 --- a/auth_test.go +++ b/auth/apikeys/apikeys_test.go @@ -1,4 +1,4 @@ -package main +package apikeys import ( "testing" @@ -10,15 +10,15 @@ func TestCheckAuth(t *testing.T) { "vFpNprT9wbHgwAubpvRxYCCpA2FQMAK6hFqPvAGrdZo=", } - if r, err := checkAuth(authKeys, ""); err != nil && r { + if r, err := CheckAuth(authKeys, ""); err != nil && r { t.Fatal("Authorization passed for empty key") } - if r, err := checkAuth(authKeys, "thisisnotvalid"); err != nil && r { + if r, err := CheckAuth(authKeys, "thisisnotvalid"); err != nil && r { t.Fatal("Authorization passed for invalid key") } - if r, err := checkAuth(authKeys, "haPVipRnGJ0QovA9nyqK"); err != nil && !r { + if r, err := CheckAuth(authKeys, "haPVipRnGJ0QovA9nyqK"); err != nil && !r { t.Fatal("Authorization failed for valid key") } } diff --git a/server.go b/server.go index dae3491..4d06db9 100644 --- a/server.go +++ b/server.go @@ -16,6 +16,7 @@ import ( "time" rice "github.com/GeertJohan/go.rice" + "github.com/andreimarcu/linx-server/auth/apikeys" "github.com/andreimarcu/linx-server/backends" "github.com/andreimarcu/linx-server/backends/localfs" "github.com/andreimarcu/linx-server/backends/s3" @@ -110,9 +111,12 @@ func setup() *web.Mux { mux.Use(AddHeaders(Config.addHeaders)) if Config.authFile != "" { - mux.Use(UploadAuth(AuthOptions{ + mux.Use(apikeys.NewApiKeysMiddleware(apikeys.AuthOptions{ AuthFile: Config.authFile, UnauthMethods: []string{"GET", "HEAD", "OPTIONS", "TRACE"}, + BasicAuth: Config.basicAuth, + SiteName: Config.siteName, + SitePath: Config.sitePath, })) } @@ -196,29 +200,10 @@ func setup() *web.Mux { mux.Get(Config.sitePath+"upload/", uploadRemote) if Config.remoteAuthFile != "" { - remoteAuthKeys = readAuthKeys(Config.remoteAuthFile) + remoteAuthKeys = apikeys.ReadAuthKeys(Config.remoteAuthFile) } } - if Config.basicAuth { - options := AuthOptions{ - AuthFile: Config.authFile, - UnauthMethods: []string{}, - } - okFunc := func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Location", Config.sitePath) - w.WriteHeader(http.StatusFound) - } - authHandler := auth{ - successHandler: http.HandlerFunc(okFunc), - failureHandler: http.HandlerFunc(badAuthorizationHandler), - authKeys: readAuthKeys(Config.authFile), - o: options, - } - mux.Head(Config.sitePath+"auth", authHandler) - mux.Get(Config.sitePath+"auth", authHandler) - } - mux.Post(Config.sitePath+"upload", uploadPostHandler) mux.Post(Config.sitePath+"upload/", uploadPostHandler) mux.Put(Config.sitePath+"upload", uploadPutHandler) diff --git a/upload.go b/upload.go index 3cac122..bda0d7f 100644 --- a/upload.go +++ b/upload.go @@ -15,6 +15,7 @@ import ( "strings" "time" + "github.com/andreimarcu/linx-server/auth/apikeys" "github.com/andreimarcu/linx-server/backends" "github.com/andreimarcu/linx-server/expiry" "github.com/dchest/uniuri" @@ -166,13 +167,16 @@ func uploadRemote(c web.C, w http.ResponseWriter, r *http.Request) { key = password } } - result, err := checkAuth(remoteAuthKeys, key) + result, err := apikeys.CheckAuth(remoteAuthKeys, key) if err != nil || !result { if Config.basicAuth { - badAuthorizationHandler(w, r) - } else { - unauthorizedHandler(c, w, r) + rs := "" + if Config.siteName != "" { + rs = fmt.Sprintf(` realm="%s"`, Config.siteName) + } + w.Header().Set("WWW-Authenticate", `Basic`+rs) } + unauthorizedHandler(c, w, r) return } }