2023-04-18 15:50:12 +00:00
|
|
|
package proxy
|
|
|
|
|
|
|
|
import (
	"fmt"
	"math"
	"net"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/azukaar/cosmos-server/src/utils"
)
|
|
|
|
|
|
|
|
/*
TODO:
- Recalculate the throttle every GB written by the response writer wrapper?
*/
|
|
|
|
|
|
|
|
// Ban severities stored in userBan.banType, from mildest to harshest.
const (
	// STRIKE is recorded when a client exceeds one of its budgets.
	STRIKE = iota
	// TEMP is a temporary ban, issued once enough strikes have been counted.
	TEMP
	// PERM is a permanent ban, issued once enough temporary bans have been counted.
	PERM
)
|
|
|
|
// userBan is one entry of the shield's ban history for a single client.
type userBan struct {
	// ClientID identifies the sanctioned client (see GetClientID).
	ClientID string
	// banType is one of STRIKE, TEMP or PERM.
	banType int
	// time is the moment the ban was recorded.
	time time.Time
	// reason is a human-readable explanation; in this file it is only
	// filled in when a strike is recorded.
	reason string
}
|
|
|
|
|
|
|
|
type smartShieldState struct {
|
|
|
|
sync.Mutex
|
|
|
|
requests []*SmartResponseWriterWrapper
|
|
|
|
bans []*userBan
|
|
|
|
}
|
|
|
|
|
|
|
|
// userUsedBudget summarises what one client consumed across its recent
// (not yet "old") requests.
type userUsedBudget struct {
	// ClientID is the client the figures belong to.
	ClientID string
	// Time is the cumulated request duration, in seconds.
	Time float64
	// Requests is the sum of the request costs.
	Requests int
	// Bytes is the sum of the bytes accounted to the requests.
	Bytes int64
	// Simultaneous is the number of requests that are still running.
	Simultaneous int
}
|
|
|
|
|
|
|
|
// shield is the package-level SmartShield state used by the middleware in this file.
var shield smartShieldState
|
|
|
|
|
2023-04-30 12:03:14 +00:00
|
|
|
func (shield *smartShieldState) GetServerNbReq() int {
|
|
|
|
shield.Lock()
|
|
|
|
defer shield.Unlock()
|
|
|
|
nbRequests := 0
|
|
|
|
|
|
|
|
for i := len(shield.requests) - 1; i >= 0; i-- {
|
|
|
|
request := shield.requests[i]
|
|
|
|
if(request.IsOld()) {
|
|
|
|
return nbRequests
|
|
|
|
}
|
|
|
|
if(!request.IsOver()) {
|
|
|
|
nbRequests++
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nbRequests
|
|
|
|
}
|
|
|
|
|
2023-04-18 15:50:12 +00:00
|
|
|
func (shield *smartShieldState) GetUserUsedBudgets(ClientID string) userUsedBudget {
|
|
|
|
shield.Lock()
|
|
|
|
defer shield.Unlock()
|
|
|
|
|
|
|
|
userConsumed := userUsedBudget{
|
|
|
|
ClientID: ClientID,
|
|
|
|
Time: 0,
|
|
|
|
Requests: 0,
|
|
|
|
Bytes: 0,
|
2023-04-30 12:03:14 +00:00
|
|
|
Simultaneous: 0,
|
2023-04-18 15:50:12 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// Check for recent requests
|
|
|
|
for i := len(shield.requests) - 1; i >= 0; i-- {
|
|
|
|
request := shield.requests[i]
|
|
|
|
if(request.IsOld()) {
|
|
|
|
return userConsumed
|
|
|
|
}
|
|
|
|
if request.ClientID == ClientID && !request.IsOld() {
|
|
|
|
if(request.IsOver()) {
|
|
|
|
userConsumed.Time += request.TimeEnded.Sub(request.TimeStarted).Seconds()
|
|
|
|
} else {
|
|
|
|
userConsumed.Time += time.Now().Sub(request.TimeStarted).Seconds()
|
2023-04-30 12:03:14 +00:00
|
|
|
userConsumed.Simultaneous++
|
2023-04-18 15:50:12 +00:00
|
|
|
}
|
|
|
|
userConsumed.Requests += request.RequestCost
|
|
|
|
userConsumed.Bytes += request.Bytes
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return userConsumed
|
|
|
|
}
|
|
|
|
|
2023-05-13 17:38:39 +00:00
|
|
|
func (shield *smartShieldState) GetLastBan(policy utils.SmartShieldPolicy, userConsumed userUsedBudget) *userBan {
|
|
|
|
shield.Lock()
|
|
|
|
defer shield.Unlock()
|
|
|
|
|
|
|
|
ClientID := userConsumed.ClientID
|
|
|
|
|
|
|
|
// Check for bans
|
|
|
|
for i := len(shield.bans) - 1; i >= 0; i-- {
|
|
|
|
ban := shield.bans[i]
|
|
|
|
if ban.banType == STRIKE && ban.ClientID == ClientID {
|
|
|
|
return ban
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-04-18 15:50:12 +00:00
|
|
|
func (shield *smartShieldState) isAllowedToReqest(policy utils.SmartShieldPolicy, userConsumed userUsedBudget) bool {
|
|
|
|
shield.Lock()
|
|
|
|
defer shield.Unlock()
|
|
|
|
|
|
|
|
ClientID := userConsumed.ClientID
|
2023-07-22 14:46:43 +00:00
|
|
|
|
|
|
|
if ClientID == "192.168.1.1" ||
|
|
|
|
ClientID == "192.168.0.1" ||
|
|
|
|
ClientID == "172.17.0.1" {
|
|
|
|
return true
|
|
|
|
}
|
2023-04-18 15:50:12 +00:00
|
|
|
|
|
|
|
nbTempBans := 0
|
|
|
|
nbStrikes := 0
|
|
|
|
|
|
|
|
// Check for bans
|
|
|
|
for i := len(shield.bans) - 1; i >= 0; i-- {
|
|
|
|
ban := shield.bans[i]
|
2023-05-13 17:38:39 +00:00
|
|
|
if ban.banType == PERM && ban.ClientID == ClientID {
|
2023-04-18 15:50:12 +00:00
|
|
|
return false
|
2023-05-13 17:38:39 +00:00
|
|
|
} else if ban.banType == TEMP && ban.ClientID == ClientID {
|
2023-04-18 15:50:12 +00:00
|
|
|
if(ban.time.Add(4 * 3600 * time.Second).Before(time.Now())) {
|
|
|
|
return false
|
|
|
|
} else if (ban.time.Add(72 * 3600 * time.Second).Before(time.Now())) {
|
|
|
|
nbTempBans++
|
|
|
|
}
|
2023-05-13 17:38:39 +00:00
|
|
|
} else if ban.banType == STRIKE && ban.ClientID == ClientID {
|
2023-04-18 15:50:12 +00:00
|
|
|
if(ban.time.Add(3600 * time.Second).Before(time.Now())) {
|
|
|
|
return false
|
|
|
|
} else if (ban.time.Add(24 * 3600 * time.Second).Before(time.Now())) {
|
|
|
|
nbStrikes++
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check for new bans
|
|
|
|
if nbTempBans >= 3 {
|
|
|
|
// perm ban
|
|
|
|
shield.bans = append(shield.bans, &userBan{
|
|
|
|
ClientID: ClientID,
|
|
|
|
banType: PERM,
|
|
|
|
time: time.Now(),
|
|
|
|
})
|
2023-05-01 10:05:35 +00:00
|
|
|
utils.Warn("User " + ClientID + " has been banned permanently: "+ fmt.Sprintf("%+v", userConsumed))
|
2023-04-18 15:50:12 +00:00
|
|
|
return false
|
|
|
|
} else if nbStrikes >= 3 {
|
|
|
|
// temp ban
|
|
|
|
shield.bans = append(shield.bans, &userBan{
|
|
|
|
ClientID: ClientID,
|
|
|
|
banType: TEMP,
|
|
|
|
time: time.Now(),
|
|
|
|
})
|
2023-05-01 10:05:35 +00:00
|
|
|
utils.Warn("User " + ClientID + " has been banned temporarily: "+ fmt.Sprintf("%+v", userConsumed))
|
2023-04-18 15:50:12 +00:00
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check for new strikes
|
|
|
|
if (userConsumed.Time > (policy.PerUserTimeBudget * float64(policy.PolicyStrictness))) ||
|
|
|
|
(userConsumed.Requests > (policy.PerUserRequestLimit * policy.PolicyStrictness)) ||
|
2023-04-30 12:03:14 +00:00
|
|
|
(userConsumed.Bytes > (policy.PerUserByteLimit * int64(policy.PolicyStrictness))) ||
|
2023-05-01 12:30:39 +00:00
|
|
|
(userConsumed.Simultaneous > (policy.PerUserSimultaneous * policy.PolicyStrictness * 15)) {
|
2023-04-18 15:50:12 +00:00
|
|
|
shield.bans = append(shield.bans, &userBan{
|
|
|
|
ClientID: ClientID,
|
|
|
|
banType: STRIKE,
|
|
|
|
time: time.Now(),
|
2023-05-13 17:38:39 +00:00
|
|
|
reason: fmt.Sprintf("%+v out of %+v", userConsumed, policy),
|
2023-04-18 15:50:12 +00:00
|
|
|
})
|
2023-05-01 10:05:35 +00:00
|
|
|
utils.Warn("User " + ClientID + " has received a strike: "+ fmt.Sprintf("%+v", userConsumed))
|
2023-04-18 15:50:12 +00:00
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
|
|
|
|
func (shield *smartShieldState) computeThrottle(policy utils.SmartShieldPolicy, userConsumed userUsedBudget) int {
|
|
|
|
shield.Lock()
|
|
|
|
defer shield.Unlock()
|
|
|
|
|
|
|
|
throttle := 0
|
|
|
|
|
|
|
|
overReq := policy.PerUserRequestLimit - userConsumed.Requests
|
|
|
|
overReqRatio := float64(overReq) / float64(policy.PerUserRequestLimit)
|
|
|
|
if overReq < 0 {
|
|
|
|
newThrottle := int(float64(2500) * -overReqRatio)
|
|
|
|
if newThrottle > throttle {
|
|
|
|
throttle = newThrottle
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
overByte := policy.PerUserByteLimit - userConsumed.Bytes
|
|
|
|
overByteRatio := float64(overByte) / float64(policy.PerUserByteLimit)
|
|
|
|
if overByte < 0 {
|
2023-04-30 12:03:14 +00:00
|
|
|
newThrottle := int(float64(40) * -overByteRatio)
|
|
|
|
if newThrottle > throttle {
|
|
|
|
throttle = newThrottle
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
overSim := policy.PerUserSimultaneous - userConsumed.Simultaneous
|
|
|
|
overSimRatio := float64(overSim) / float64(policy.PerUserSimultaneous)
|
|
|
|
if overSim < 0 {
|
2023-04-30 23:56:10 +00:00
|
|
|
newThrottle := int(float64(50) * -overSimRatio)
|
2023-04-18 15:50:12 +00:00
|
|
|
if newThrottle > throttle {
|
|
|
|
throttle = newThrottle
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if throttle > 0 {
|
|
|
|
utils.Debug(fmt.Sprintf("User Time: %f, Requests: %d, Bytes: %d", userConsumed.Time, userConsumed.Requests, userConsumed.Bytes))
|
|
|
|
utils.Debug(fmt.Sprintf("Policy Time: %f, Requests: %d, Bytes: %d", policy.PerUserTimeBudget, policy.PerUserRequestLimit, policy.PerUserByteLimit))
|
|
|
|
utils.Debug(fmt.Sprintf("Throttling: %d", throttle))
|
|
|
|
}
|
|
|
|
|
|
|
|
return throttle
|
|
|
|
}
|
|
|
|
|
2023-04-20 18:35:06 +00:00
|
|
|
func calculateLowestExhaustedPercentage(policy utils.SmartShieldPolicy, userConsumed userUsedBudget) int64 {
|
|
|
|
timeExhaustedPercentage := 100 - (userConsumed.Time / policy.PerUserTimeBudget) * 100
|
|
|
|
requestsExhaustedPercentage := 100 - (float64(userConsumed.Requests) / float64(policy.PerUserRequestLimit)) * 100
|
|
|
|
bytesExhaustedPercentage := 100 - (float64(userConsumed.Bytes) / float64(policy.PerUserByteLimit)) * 100
|
|
|
|
|
|
|
|
// utils.Debug(fmt.Sprintf("Time: %f, Requests: %d, Bytes: %d", timeExhaustedPercentage, requestsExhaustedPercentage, bytesExhaustedPercentage))
|
|
|
|
|
|
|
|
return int64(math.Max(0, math.Min(math.Min(timeExhaustedPercentage, requestsExhaustedPercentage), bytesExhaustedPercentage)))
|
|
|
|
}
|
|
|
|
|
2023-04-18 15:50:12 +00:00
|
|
|
func GetClientID(r *http.Request) string {
|
2023-05-19 16:01:11 +00:00
|
|
|
// when using Docker we need to get the real IP
|
2023-07-19 10:27:48 +00:00
|
|
|
utils.Debug("SmartShield TEMPLOG: Getting client ID")
|
|
|
|
utils.Debug("SmartShield TEMPLOG HOSTNAME: " + os.Getenv("HOSTNAME"))
|
|
|
|
utils.Debug("SmartShield TEMPLOG x-forwarded-for: " + r.Header.Get("x-forwarded-for"))
|
|
|
|
utils.Debug("SmartShield TEMPLOG RemoteAddr: " + r.RemoteAddr)
|
|
|
|
|
2023-05-20 19:58:25 +00:00
|
|
|
if os.Getenv("HOSTNAME") != "" && r.Header.Get("x-forwarded-for") != "" {
|
2023-05-19 16:01:11 +00:00
|
|
|
ip, _, _ := net.SplitHostPort(r.Header.Get("x-forwarded-for"))
|
|
|
|
utils.Debug("SmartShield: Getting client ID " + ip)
|
|
|
|
return ip
|
|
|
|
} else {
|
|
|
|
ip, _, _ := net.SplitHostPort(r.RemoteAddr)
|
|
|
|
utils.Debug("SmartShield: Getting client ID " + ip)
|
|
|
|
return ip
|
|
|
|
}
|
2023-04-18 15:50:12 +00:00
|
|
|
}
|
|
|
|
|
2023-04-30 12:03:14 +00:00
|
|
|
func isPrivileged(req *http.Request, policy utils.SmartShieldPolicy) bool {
|
|
|
|
role, _ := strconv.Atoi(req.Header.Get("x-cosmos-role"))
|
|
|
|
return role >= policy.PrivilegedGroups
|
|
|
|
}
|
|
|
|
|
2023-04-18 15:50:12 +00:00
|
|
|
func SmartShieldMiddleware(policy utils.SmartShieldPolicy) func(http.Handler) http.Handler {
|
|
|
|
if policy.Enabled == false {
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
|
|
return next
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
if(policy.PerUserTimeBudget == 0) {
|
|
|
|
policy.PerUserTimeBudget = 2 * 60 * 60 * 1000 // 2 hours
|
|
|
|
}
|
|
|
|
if(policy.PerUserRequestLimit == 0) {
|
|
|
|
policy.PerUserRequestLimit = 6000 // 100 requests per minute
|
|
|
|
}
|
|
|
|
if(policy.PerUserByteLimit == 0) {
|
2023-04-30 12:03:14 +00:00
|
|
|
policy.PerUserByteLimit = 150 * 1024 * 1024 * 1024 // 150GB
|
2023-04-18 15:50:12 +00:00
|
|
|
}
|
|
|
|
if(policy.PolicyStrictness == 0) {
|
|
|
|
policy.PolicyStrictness = 2 // NORMAL
|
|
|
|
}
|
2023-04-30 12:03:14 +00:00
|
|
|
if(policy.PerUserSimultaneous == 0) {
|
2023-04-30 23:56:10 +00:00
|
|
|
policy.PerUserSimultaneous = 24
|
2023-04-30 12:03:14 +00:00
|
|
|
}
|
|
|
|
if(policy.MaxGlobalSimultaneous == 0) {
|
2023-04-30 23:56:10 +00:00
|
|
|
policy.MaxGlobalSimultaneous = 250
|
2023-04-30 12:03:14 +00:00
|
|
|
}
|
|
|
|
if(policy.PrivilegedGroups == 0) {
|
|
|
|
policy.PrivilegedGroups = utils.ADMIN
|
|
|
|
}
|
2023-04-18 15:50:12 +00:00
|
|
|
}
|
2023-04-30 12:03:14 +00:00
|
|
|
|
2023-04-18 15:50:12 +00:00
|
|
|
return func(next http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
utils.Log("SmartShield: Request received")
|
2023-04-30 12:03:14 +00:00
|
|
|
currentGlobalRequests := shield.GetServerNbReq() + 1
|
|
|
|
utils.Debug(fmt.Sprintf("SmartShield: Current global requests: %d", currentGlobalRequests))
|
|
|
|
|
2023-05-01 12:30:39 +00:00
|
|
|
if !isPrivileged(r, policy) {
|
|
|
|
tooManyReq := currentGlobalRequests > policy.MaxGlobalSimultaneous
|
|
|
|
wayTooManyReq := currentGlobalRequests > policy.MaxGlobalSimultaneous * 10
|
|
|
|
retries := 50
|
|
|
|
if wayTooManyReq {
|
|
|
|
utils.Log("SmartShield: WAYYYY Too many users on the server. Aborting right away.")
|
|
|
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
for tooManyReq {
|
|
|
|
time.Sleep(5000 * time.Millisecond)
|
|
|
|
currentGlobalRequests := shield.GetServerNbReq() + 1
|
|
|
|
tooManyReq = currentGlobalRequests > policy.MaxGlobalSimultaneous
|
|
|
|
retries--
|
|
|
|
if retries <= 0 {
|
|
|
|
utils.Log("SmartShield: Too many users on the server")
|
|
|
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
2023-04-30 12:03:14 +00:00
|
|
|
}
|
|
|
|
|
2023-04-18 15:50:12 +00:00
|
|
|
clientID := GetClientID(r)
|
|
|
|
userConsumed := shield.GetUserUsedBudgets(clientID)
|
|
|
|
|
2023-04-30 12:03:14 +00:00
|
|
|
if !isPrivileged(r, policy) && !shield.isAllowedToReqest(policy, userConsumed) {
|
2023-05-13 17:38:39 +00:00
|
|
|
lastBan := shield.GetLastBan(policy, userConsumed)
|
|
|
|
utils.Log("SmartShield: User is blocked due to abuse: " + fmt.Sprintf("%+v", lastBan))
|
2023-04-18 15:50:12 +00:00
|
|
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
|
|
|
return
|
|
|
|
} else {
|
2023-04-30 12:03:14 +00:00
|
|
|
throttle := 0
|
|
|
|
if(!isPrivileged(r, policy)) {
|
|
|
|
throttle = shield.computeThrottle(policy, userConsumed)
|
|
|
|
}
|
2023-04-18 15:50:12 +00:00
|
|
|
wrapper := &SmartResponseWriterWrapper {
|
|
|
|
ResponseWriter: w,
|
|
|
|
ThrottleNext: throttle,
|
|
|
|
TimeStarted: time.Now(),
|
|
|
|
ClientID: clientID,
|
|
|
|
RequestCost: 1,
|
|
|
|
Method: r.Method,
|
|
|
|
shield: shield,
|
|
|
|
policy: policy,
|
2023-04-30 12:03:14 +00:00
|
|
|
isPrivileged: isPrivileged(r, policy),
|
2023-04-18 15:50:12 +00:00
|
|
|
}
|
|
|
|
|
2023-04-20 18:35:06 +00:00
|
|
|
// add rate limite headers
|
|
|
|
In20Minutes := strconv.FormatInt(time.Now().Add(20 * time.Minute).Unix(), 10)
|
|
|
|
w.Header().Set("X-RateLimit-Remaining", strconv.FormatInt(calculateLowestExhaustedPercentage(policy, userConsumed), 10))
|
|
|
|
w.Header().Set("X-RateLimit-Limit", strconv.FormatInt(int64(policy.PerUserRequestLimit), 10))
|
|
|
|
w.Header().Set("X-RateLimit-Reset", In20Minutes)
|
|
|
|
|
2023-04-18 15:50:12 +00:00
|
|
|
shield.Lock()
|
|
|
|
shield.requests = append(shield.requests, wrapper)
|
|
|
|
shield.Unlock()
|
|
|
|
|
2023-04-30 12:03:14 +00:00
|
|
|
ctx := r.Context()
|
|
|
|
done := make(chan struct{})
|
2023-04-18 15:50:12 +00:00
|
|
|
|
2023-04-30 12:03:14 +00:00
|
|
|
go (func() {
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
case <-done:
|
|
|
|
}
|
|
|
|
shield.Lock()
|
|
|
|
wrapper.TimeEnded = time.Now()
|
|
|
|
wrapper.isOver = true
|
|
|
|
shield.Unlock()
|
|
|
|
})()
|
|
|
|
|
|
|
|
next.ServeHTTP(wrapper, r)
|
|
|
|
close(done)
|
2023-04-18 15:50:12 +00:00
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|