2023-03-10 20:59:56 +00:00
package utils
import (
"context"
"net/http"
"time"
2023-04-30 12:03:14 +00:00
"net"
2023-05-07 16:47:20 +00:00
"strings"
2023-05-06 18:32:49 +00:00
"fmt"
2023-04-30 12:03:14 +00:00
2023-04-18 15:50:12 +00:00
"github.com/mxk/go-flowrate/flowrate"
2023-04-30 12:03:14 +00:00
"github.com/oschwald/geoip2-golang"
2023-03-10 20:59:56 +00:00
)
// https://github.com/go-chi/chi/blob/master/middleware/timeout.go
func MiddlewareTimeout ( timeout time . Duration ) func ( next http . Handler ) http . Handler {
return func ( next http . Handler ) http . Handler {
fn := func ( w http . ResponseWriter , r * http . Request ) {
ctx , cancel := context . WithTimeout ( r . Context ( ) , timeout )
defer func ( ) {
cancel ( )
if ctx . Err ( ) == context . DeadlineExceeded {
Error ( "Request Timeout. Cancelling." , ctx . Err ( ) )
HTTPError ( w , "Gateway Timeout" ,
http . StatusGatewayTimeout , "HTTP002" )
return
}
} ( )
2023-04-04 20:54:35 +00:00
w . Header ( ) . Set ( "X-Timeout-Duration" , timeout . String ( ) )
2023-03-10 20:59:56 +00:00
r = r . WithContext ( ctx )
next . ServeHTTP ( w , r )
}
return http . HandlerFunc ( fn )
}
}
2023-04-18 15:50:12 +00:00
// responseWriter sends response body bytes through a flowrate.Writer so
// they are rate limited. The other http.ResponseWriter methods (Header,
// WriteHeader) come from the embedded ResponseWriter.
//
// NOTE(review): only these two values are embedded, so optional interfaces
// of the original writer (e.g. http.Flusher) are not exposed through it.
type responseWriter struct {
	http.ResponseWriter
	*flowrate.Writer
}

// Write forwards p to the rate-limited writer. It has to be declared here
// because both embedded fields provide a Write method at the same depth,
// which would otherwise make the promoted method ambiguous.
func (rw *responseWriter) Write(p []byte) (int, error) {
	limited := rw.Writer
	return limited.Write(p)
}
func BandwithLimiterMiddleware ( max int64 ) func ( next http . Handler ) http . Handler {
return func ( next http . Handler ) http . Handler {
return http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
if ( max > 0 ) {
fw := flowrate . NewWriter ( w , max )
w = & responseWriter { w , fw }
}
next . ServeHTTP ( w , r )
} )
}
}
2023-03-10 20:59:56 +00:00
func SetSecurityHeaders ( next http . Handler ) http . Handler {
return http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
if ( IsHTTPS ) {
// TODO: Add preload if we have a valid certificate
w . Header ( ) . Set ( "Strict-Transport-Security" , "max-age=31536000; includeSubDomains" )
}
2023-04-27 18:29:26 +00:00
2023-03-10 20:59:56 +00:00
w . Header ( ) . Set ( "X-Content-Type-Options" , "nosniff" )
2023-06-21 00:07:44 +00:00
w . Header ( ) . Set ( "Content-Security-Policy" , "frame-ancestors 'self'" )
2023-03-10 20:59:56 +00:00
w . Header ( ) . Set ( "X-XSS-Protection" , "1; mode=block" )
2023-03-18 19:59:32 +00:00
w . Header ( ) . Set ( "X-Served-By-Cosmos" , "1" )
2023-03-10 20:59:56 +00:00
next . ServeHTTP ( w , r )
} )
}
// CORSHeader returns a middleware that stamps fixed CORS response headers on
// every request: the given allowed origin, a fixed list of methods and
// request headers, and Access-Control-Allow-Credentials: true. Every request,
// including OPTIONS preflights, is then passed on to next.
//
// NOTE(review): browsers refuse credentialed responses whose allowed origin
// is "*", so callers need to pass a concrete origin for credentials to work.
func CORSHeader(origin string) func(next http.Handler) http.Handler {
	const (
		allowedMethods = "GET, POST, PUT, DELETE, OPTIONS"
		allowedHeaders = "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization"
	)

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			h.Set("Access-Control-Allow-Origin", origin)
			h.Set("Access-Control-Allow-Methods", allowedMethods)
			h.Set("Access-Control-Allow-Headers", allowedHeaders)
			h.Set("Access-Control-Allow-Credentials", "true")
			next.ServeHTTP(w, r)
		})
	}
}
// AcceptHeader returns a middleware that sets the response Content-Type
// header to accept before calling next, which may still override it.
// Despite its name, it neither reads nor enforces the request's Accept header.
func AcceptHeader(accept string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		setContentType := func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", accept)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(setContentType)
	}
}
2023-04-30 12:03:14 +00:00
// GetIPLocation returns the ISO country code for a given IP address, as
// recorded in the GeoLite2-Country.mmdb database (a relative path, so it is
// resolved against the process working directory). The code may be empty
// when the database holds no country for the address.
//
// It returns an error if ip is not a valid IP literal, if the database
// cannot be opened, or if the lookup fails.
//
// NOTE(review): the database is opened and closed on every call, and
// BlockByCountryMiddleware calls this once per request; caching an open
// reader would avoid that cost if the file is not replaced while running.
func GetIPLocation(ip string) (string, error) {
	// Validate first so malformed input never costs a database open and a
	// nil net.IP is never handed to the lookup.
	parsedIP := net.ParseIP(ip)
	if parsedIP == nil {
		return "", fmt.Errorf("invalid IP address %q", ip)
	}

	geoDB, err := geoip2.Open("GeoLite2-Country.mmdb")
	if err != nil {
		return "", fmt.Errorf("opening GeoLite2 country database: %w", err)
	}
	defer geoDB.Close()

	record, err := geoDB.Country(parsedIP)
	if err != nil {
		return "", fmt.Errorf("looking up country for %s: %w", ip, err)
	}
	return record.Country.IsoCode, nil
}
// BlockByCountryMiddleware returns a middleware function that blocks requests
// from specified countries. A request whose client IP (taken from
// r.RemoteAddr) geolocates to one of blockedCountries gets 403 "Access
// denied", unless that country equals GetMainConfig().ServerCountry.
//
// A RemoteAddr that cannot be split into host and port gets 400 "Invalid
// request". If the country cannot be determined, a warning is logged and
// the request is let through.
func BlockByCountryMiddleware(blockedCountries []string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ip, _, err := net.SplitHostPort(r.RemoteAddr)
			if err != nil {
				http.Error(w, "Invalid request", http.StatusBadRequest)
				return
			}

			// Nothing can be blocked, so skip the per-request GeoIP lookup
			// (which opens the database each time).
			if len(blockedCountries) == 0 {
				next.ServeHTTP(w, r)
				return
			}

			countryCode, err := GetIPLocation(ip)
			if err != nil {
				Warn("Missing geolocation information to block IPs")
				next.ServeHTTP(w, r)
				return
			}

			if countryCode == "" {
				Debug("Country code is empty")
			} else {
				Debug("Country code: " + countryCode)
			}

			// Never block visitors from the server's own country, even when
			// it appears in the list; this check is the same for every entry.
			if countryCode != GetMainConfig().ServerCountry {
				for _, blockedCountry := range blockedCountries {
					if countryCode == blockedCountry {
						http.Error(w, "Access denied", http.StatusForbidden)
						return
					}
				}
			}

			next.ServeHTTP(w, r)
		})
	}
}
// BlockPostWithoutReferer is a middleware that rejects state-changing
// requests (POST, PUT, PATCH and DELETE) that carry no Referer header,
// answering 400 "Bad Request: Invalid request.". Only the presence of the
// header is checked; its value is not validated. All other requests are
// passed to next unchanged.
func BlockPostWithoutReferer(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
			if r.Header.Get("Referer") == "" {
				http.Error(w, "Bad Request: Invalid request.", http.StatusBadRequest)
				return
			}
		}
		next.ServeHTTP(w, r)
	})
}
2023-05-06 18:25:10 +00:00
func EnsureHostname ( next http . Handler ) http . Handler {
return http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2023-05-10 17:17:11 +00:00
Debug ( "Ensuring origin for requested resource from : " + r . Host )
2023-05-06 18:25:10 +00:00
og := GetMainConfig ( ) . HTTPConfig . Hostname
ni := GetMainConfig ( ) . NewInstall
if ni || og == "0.0.0.0" {
next . ServeHTTP ( w , r )
return
}
2023-05-18 12:42:47 +00:00
hostnames := GetAllHostnames ( false , false )
2023-05-06 18:25:10 +00:00
2023-05-07 16:47:20 +00:00
reqHostNoPort := strings . Split ( r . Host , ":" ) [ 0 ]
2023-05-06 18:32:49 +00:00
isOk := false
2023-05-06 18:25:10 +00:00
for _ , hostname := range hostnames {
2023-05-10 17:17:11 +00:00
hostnameNoPort := strings . Split ( hostname , ":" ) [ 0 ]
if reqHostNoPort == hostnameNoPort {
2023-05-06 18:32:49 +00:00
isOk = true
2023-05-06 18:25:10 +00:00
}
}
2023-05-06 18:32:49 +00:00
if ! isOk {
Error ( "Invalid Hostname " + r . Host + " for request. Expecting one of " + fmt . Sprintf ( "%v" , hostnames ) , nil )
w . WriteHeader ( http . StatusBadRequest )
2023-07-04 10:52:38 +00:00
http . Error ( w , "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed." , http . StatusBadRequest )
2023-05-06 18:32:49 +00:00
return
}
2023-05-06 18:25:10 +00:00
next . ServeHTTP ( w , r )
} )
2023-05-16 17:08:01 +00:00
}
func IsValidHostname ( hostname string ) bool {
og := GetMainConfig ( ) . HTTPConfig . Hostname
ni := GetMainConfig ( ) . NewInstall
if ni || og == "0.0.0.0" {
return true
}
2023-05-18 12:42:47 +00:00
hostnames := GetAllHostnames ( false , false )
2023-05-16 17:08:01 +00:00
reqHostNoPort := strings . Split ( hostname , ":" ) [ 0 ]
reqHostNoPortNoSubdomain := ""
if parts := strings . Split ( reqHostNoPort , "." ) ; len ( parts ) < 2 {
reqHostNoPortNoSubdomain = reqHostNoPort
} else {
reqHostNoPortNoSubdomain = parts [ len ( parts ) - 2 ] + "." + parts [ len ( parts ) - 1 ]
}
for _ , hostname := range hostnames {
hostnameNoPort := strings . Split ( hostname , ":" ) [ 0 ]
hostnameNoPortNoSubdomain := ""
if parts := strings . Split ( hostnameNoPort , "." ) ; len ( parts ) < 2 {
hostnameNoPortNoSubdomain = hostnameNoPort
} else {
hostnameNoPortNoSubdomain = parts [ len ( parts ) - 2 ] + "." + parts [ len ( parts ) - 1 ]
}
if reqHostNoPortNoSubdomain == hostnameNoPortNoSubdomain {
return true
}
}
return false
2023-05-06 18:25:10 +00:00
}