Cosmos-Server/src/utils/middleware.go

205 lines
5.4 KiB
Go
Raw Normal View History

2023-03-10 20:59:56 +00:00
package utils
import (
"context"
"net/http"
"time"
"net"
2023-05-07 16:47:20 +00:00
"strings"
2023-05-06 18:32:49 +00:00
"fmt"
"github.com/mxk/go-flowrate/flowrate"
"github.com/oschwald/geoip2-golang"
2023-03-10 20:59:56 +00:00
)
// https://github.com/go-chi/chi/blob/master/middleware/timeout.go
func MiddlewareTimeout(timeout time.Duration) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer func() {
cancel()
if ctx.Err() == context.DeadlineExceeded {
Error("Request Timeout. Cancelling.", ctx.Err())
HTTPError(w, "Gateway Timeout",
http.StatusGatewayTimeout, "HTTP002")
return
}
}()
w.Header().Set("X-Timeout-Duration", timeout.String())
2023-03-10 20:59:56 +00:00
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}
type responseWriter struct {
http.ResponseWriter
*flowrate.Writer
}
func (w *responseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}
func BandwithLimiterMiddleware(max int64) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if(max > 0) {
fw := flowrate.NewWriter(w, max)
w = &responseWriter{w, fw}
}
next.ServeHTTP(w, r)
})
}
}
2023-03-10 20:59:56 +00:00
func SetSecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if(IsHTTPS) {
// TODO: Add preload if we have a valid certificate
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
}
2023-04-27 18:29:26 +00:00
2023-03-10 20:59:56 +00:00
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
2023-03-18 19:59:32 +00:00
w.Header().Set("X-Served-By-Cosmos", "1")
2023-03-10 20:59:56 +00:00
next.ServeHTTP(w, r)
})
}
func CORSHeader(origin string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2023-03-18 19:59:32 +00:00
2023-03-10 20:59:56 +00:00
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
w.Header().Set("Access-Control-Allow-Credentials", "true")
next.ServeHTTP(w, r)
})
}
}
func AcceptHeader(accept string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", accept)
next.ServeHTTP(w, r)
})
}
}
// GetIPLocation returns the ISO country code for a given IP address.
func GetIPLocation(ip string) (string, error) {
geoDB, err := geoip2.Open("GeoLite2-Country.mmdb")
if err != nil {
return "", err
}
defer geoDB.Close()
parsedIP := net.ParseIP(ip)
record, err := geoDB.Country(parsedIP)
if err != nil {
return "", err
}
return record.Country.IsoCode, nil
}
// BlockByCountryMiddleware returns a middleware function that blocks requests from specified countries.
func BlockByCountryMiddleware(blockedCountries []string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
countryCode, err := GetIPLocation(ip)
if err == nil {
if countryCode == "" {
Debug("Country code is empty")
} else {
Debug("Country code: " + countryCode)
}
config := GetMainConfig()
for _, blockedCountry := range blockedCountries {
if config.ServerCountry != countryCode && countryCode == blockedCountry {
http.Error(w, "Access denied", http.StatusForbidden)
return
}
}
} else {
Warn("Missing geolocation information to block IPs")
}
next.ServeHTTP(w, r)
})
}
}
// blockPostWithoutReferer blocks POST requests without a Referer header
func BlockPostWithoutReferer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" || r.Method == "DELETE" {
referer := r.Header.Get("Referer")
if referer == "" {
http.Error(w, "Bad Request: Invalid request.", http.StatusBadRequest)
return
}
}
// If it's not a POST request or the POST request has a Referer header, pass the request to the next handler
next.ServeHTTP(w, r)
})
2023-05-01 10:00:45 +00:00
}
2023-05-06 18:25:10 +00:00
func EnsureHostname(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2023-05-10 17:17:11 +00:00
Debug("Ensuring origin for requested resource from : " + r.Host)
2023-05-06 18:25:10 +00:00
og := GetMainConfig().HTTPConfig.Hostname
ni := GetMainConfig().NewInstall
if ni || og == "0.0.0.0" {
next.ServeHTTP(w, r)
return
}
hostnames := GetAllHostnames()
2023-05-07 16:47:20 +00:00
reqHostNoPort := strings.Split(r.Host, ":")[0]
2023-05-06 18:32:49 +00:00
isOk := false
2023-05-06 18:25:10 +00:00
for _, hostname := range hostnames {
2023-05-10 17:17:11 +00:00
hostnameNoPort := strings.Split(hostname, ":")[0]
if reqHostNoPort == hostnameNoPort {
2023-05-06 18:32:49 +00:00
isOk = true
2023-05-06 18:25:10 +00:00
}
}
2023-05-06 18:32:49 +00:00
if !isOk {
Error("Invalid Hostname " + r.Host + " for request. Expecting one of " + fmt.Sprintf("%v", hostnames), nil)
w.WriteHeader(http.StatusBadRequest)
http.Error(w, "Bad Request: Invalid hostname.", http.StatusBadRequest)
return
}
2023-05-06 18:25:10 +00:00
next.ServeHTTP(w, r)
})
}