[server] Rate limit (#1121)

## Description

## Tests
This commit is contained in:
Neeraj Gupta 2024-03-17 09:59:59 +05:30 committed by GitHub
parent a5340764a8
commit 12b9ac4db6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 52 additions and 11 deletions

View file

@ -178,7 +178,8 @@ func main() {
authCache := cache.New(1*time.Minute, 15*time.Minute)
accessTokenCache := cache.New(1*time.Minute, 15*time.Minute)
discordController := discord.NewDiscordController(userRepo, hostName, environment)
rateLimiter := middleware.NewRateLimitMiddleware(discordController)
rateLimiter := middleware.NewRateLimitMiddleware(discordController, 5000, 5*time.Second)
defer rateLimiter.Stop()
emailNotificationCtrl := &email.EmailNotificationController{
UserRepo: userRepo,
@ -360,22 +361,22 @@ func main() {
server.Use(requestid.New(), middleware.Logger(urlSanitizer), cors(), gzip.Gzip(gzip.DefaultCompression), middleware.PanicRecover())
publicAPI := server.Group("/")
publicAPI.Use(rateLimiter.APIRateLimitMiddleware(urlSanitizer))
publicAPI.Use(rateLimiter.GlobalRateLimiter(), rateLimiter.APIRateLimitMiddleware(urlSanitizer))
privateAPI := server.Group("/")
privateAPI.Use(authMiddleware.TokenAuthMiddleware(nil), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
privateAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(nil), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
adminAPI := server.Group("/admin")
adminAPI.Use(authMiddleware.TokenAuthMiddleware(nil), authMiddleware.AdminAuthMiddleware())
adminAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(nil), authMiddleware.AdminAuthMiddleware())
paymentJwtAuthAPI := server.Group("/")
paymentJwtAuthAPI.Use(authMiddleware.TokenAuthMiddleware(jwt.PAYMENT.Ptr()))
paymentJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.PAYMENT.Ptr()))
familiesJwtAuthAPI := server.Group("/")
//The middleware order matters. First, the userID must be set in the context, so that we can apply limit for user.
familiesJwtAuthAPI.Use(authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
familiesJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
publicCollectionAPI := server.Group("/public-collection")
publicCollectionAPI.Use(accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer))
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer))
healthCheckHandler := &api.HealthCheckHandler{
DB: db,
@ -472,7 +473,7 @@ func main() {
privateAPI.DELETE("/users/delete", userHandler.DeleteUser)
accountsJwtAuthAPI := server.Group("/")
accountsJwtAuthAPI.Use(authMiddleware.TokenAuthMiddleware(jwt.ACCOUNTS.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
accountsJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.ACCOUNTS.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
passkeysHandler := &api.PasskeyHandler{
Controller: passkeyCtrl,
}
@ -531,7 +532,7 @@ func main() {
castCtrl := cast.NewController(&castDb, accessCtrl)
castMiddleware := middleware.CastMiddleware{CastCtrl: castCtrl, Cache: authCache}
castAPI.Use(castMiddleware.CastAuthMiddleware())
castAPI.Use(rateLimiter.GlobalRateLimiter(), castMiddleware.CastAuthMiddleware())
castHandler := &api.CastHandler{
CollectionCtrl: collectionController,

View file

@ -5,6 +5,8 @@ import (
"net/http"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/ente-io/museum/pkg/controller/discord"
"github.com/ente-io/museum/pkg/utils/auth"
@ -20,14 +22,40 @@ type RateLimitMiddleware struct {
limit10ReqPerMin *limiter.Limiter
limit200ReqPerSec *limiter.Limiter
discordCtrl *discord.DiscordController
count int64 // Use int64 for atomic operations
limit int64
reset time.Duration
ticker *time.Ticker
}
func NewRateLimitMiddleware(discordCtrl *discord.DiscordController) *RateLimitMiddleware {
return &RateLimitMiddleware{
func NewRateLimitMiddleware(discordCtrl *discord.DiscordController, limit int64, reset time.Duration) *RateLimitMiddleware {
rl := &RateLimitMiddleware{
limit10ReqPerMin: rateLimiter("10-M"),
limit200ReqPerSec: rateLimiter("200-S"),
discordCtrl: discordCtrl,
limit: limit,
reset: reset,
ticker: time.NewTicker(reset),
}
go func() {
for range rl.ticker.C {
atomic.StoreInt64(&rl.count, 0) // Reset the count every reset interval
}
}()
return rl
}
// Increment increments the counter in a thread-safe manner.
// Returns true if the increment was within the rate limit, false if the rate limit was exceeded.
func (r *RateLimitMiddleware) Increment() bool {
// Atomically increment the count
newCount := atomic.AddInt64(&r.count, 1)
return newCount <= r.limit
}
// Stop the internal ticker, effectively stopping the rate limiter.
func (r *RateLimitMiddleware) Stop() {
r.ticker.Stop()
}
// rateLimiter will return instance of limiter.Limiter based on internal <limit>-<period>
@ -44,6 +72,18 @@ func rateLimiter(interval string) *limiter.Limiter {
return instance
}
// GlobalRateLimiter rate limits all requests to the server, regardless of the endpoint.
func (r *RateLimitMiddleware) GlobalRateLimiter() gin.HandlerFunc {
return func(c *gin.Context) {
if !r.Increment() {
go r.discordCtrl.NotifyPotentialAbuse("Global rate limit breached")
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Rate limit breached, try later"})
return
}
c.Next()
}
}
// APIRateLimitMiddleware only rate limits sensitive public endpoints which have a higher risk
// of abuse by any bad actor.
func (r *RateLimitMiddleware) APIRateLimitMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {