From efde852c54f22e09da962acefff8ad0477b2f628 Mon Sep 17 00:00:00 2001 From: Solderpunk Date: Sat, 18 Mar 2023 16:40:23 +0100 Subject: [PATCH] Refactor rate limiting to have soft and hard limits, block clients exceeding hard limits for one hour. --- config.go | 6 ++++-- handler.go | 10 +++++++--- launch.go | 6 ++++-- ratelim.go | 38 ++++++++++++++++++++++++++++++++------ 4 files changed, 47 insertions(+), 13 deletions(-) diff --git a/config.go b/config.go index 8e2e749..6d0c91f 100644 --- a/config.go +++ b/config.go @@ -24,7 +24,8 @@ type SysConfig struct { AllowTLS12 bool RateLimitEnable bool RateLimitAverage int - RateLimitBurst int + RateLimitSoft int + RateLimitHard int } type UserConfig struct { @@ -61,7 +62,8 @@ func getConfig(filename string) (SysConfig, UserConfig, error) { sysConfig.AllowTLS12 = true sysConfig.RateLimitEnable = false sysConfig.RateLimitAverage = 1 - sysConfig.RateLimitBurst = 10 + sysConfig.RateLimitSoft = 10 + sysConfig.RateLimitHard = 50 userConfig.GeminiExt = "gmi" userConfig.DefaultLang = "" diff --git a/handler.go b/handler.go index afd58a6..093306d 100644 --- a/handler.go +++ b/handler.go @@ -53,9 +53,13 @@ func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig, if sysConfig.RateLimitEnable { noPort := logEntry.RemoteAddr.String() noPort = noPort[0:strings.LastIndex(noPort, ":")] - drips, allowed := rl.Allowed(noPort) - if !allowed { - conn.Write([]byte("44 " + strconv.Itoa(drips) + " second cool down, please!\r\n")) + limited := rl.hardLimited(noPort) + if limited { + conn.Close() + } + delay, limited := rl.softLimited(noPort) + if limited { + conn.Write([]byte("44 " + strconv.Itoa(delay) + " second cool down, please!\r\n")) logEntry.Status = 44 return } diff --git a/launch.go b/launch.go index 1dcaba3..3738f4b 100644 --- a/launch.go +++ b/launch.go @@ -140,7 +140,9 @@ func launch(sysConfig SysConfig, userConfig UserConfig, privInfo userInfo) int { go func() { for { entry := <-accessLogEntries - writeLogEntry(accessLogFile, entry) + if entry.Status != 0 { + writeLogEntry(accessLogFile, entry) + } } }() } @@ -159,7 +161,7 @@ func launch(sysConfig SysConfig, userConfig UserConfig, privInfo userInfo) int { // Infinite serve loop (SIGTERM breaks out) running := true var wg sync.WaitGroup - rl := newRateLimiter(sysConfig.RateLimitAverage, sysConfig.RateLimitBurst) + rl := newRateLimiter(sysConfig.RateLimitAverage, sysConfig.RateLimitSoft, sysConfig.RateLimitHard) for running { conn, err := listener.Accept() if err == nil { diff --git a/ratelim.go b/ratelim.go index 9c97ec9..8ca2f1f 100644 --- a/ratelim.go +++ b/ratelim.go @@ -1,6 +1,7 @@ package main import ( + "log" "sync" "time" ) @@ -8,19 +9,25 @@ import ( type RateLimiter struct { mu sync.Mutex bucket map[string]int + bans map[string]time.Time rate int - burst int + softLimit int + hardLimit int } -func newRateLimiter(rate int, burst int) RateLimiter { +func newRateLimiter(rate int, softLimit int, hardLimit int) RateLimiter { var rl = new(RateLimiter) rl.bucket = make(map[string]int) + rl.bans = make(map[string]time.Time) rl.rate = rate - rl.burst = burst + rl.softLimit = softLimit + rl.hardLimit = hardLimit + // Leak periodically go func () { for(true) { rl.mu.Lock() + // Leak the buckets for addr, drips := range rl.bucket { if drips <= rate { delete(rl.bucket, addr) @@ -28,6 +35,15 @@ func newRateLimiter(rate int, burst int) RateLimiter { rl.bucket[addr] = drips - rl.rate } } + // Expire bans + now := time.Now() + for addr, expiry := range rl.bans { + if now.After(expiry) { + delete(rl.bans, addr) + } + } + + // Wait rl.mu.Unlock() time.Sleep(time.Second) } @@ -35,16 +51,26 @@ func newRateLimiter(rate int, burst int) RateLimiter { return *rl } -func (rl *RateLimiter) Allowed(addr string) (int, bool) { +func (rl *RateLimiter) softLimited(addr string) (int, bool) { rl.mu.Lock() defer rl.mu.Unlock() drips, present := rl.bucket[addr] if !present { rl.bucket[addr] = 1 - return 1, true + return 1, false } drips += 1 rl.bucket[addr] = drips - return drips, drips < rl.burst + if drips > rl.hardLimit { + now := time.Now() + expiry := now.Add(time.Hour) + rl.bans[addr] = expiry + log.Println("Banning " + addr + "for 1 hour due to ignoring rate limiting.") + } + return drips, drips > rl.softLimit } +func (rl *RateLimiter) hardLimited(addr string) bool { + _, present := rl.bans[addr] + return present +}