1
1
Fork 0
mirror of https://tildegit.org/solderpunk/molly-brown synced 2024-05-12 16:06:03 +02:00

Refactor rate limiting to have soft and hard limits, block clients exceeding hard limits for one hour.

This commit is contained in:
Solderpunk 2023-03-18 16:40:23 +01:00
parent 3c5835f033
commit efde852c54
4 changed files with 47 additions and 13 deletions

View File

@ -24,7 +24,8 @@ type SysConfig struct {
AllowTLS12 bool AllowTLS12 bool
RateLimitEnable bool RateLimitEnable bool
RateLimitAverage int RateLimitAverage int
RateLimitBurst int RateLimitSoft int
RateLimitHard int
} }
type UserConfig struct { type UserConfig struct {
@ -61,7 +62,8 @@ func getConfig(filename string) (SysConfig, UserConfig, error) {
sysConfig.AllowTLS12 = true sysConfig.AllowTLS12 = true
sysConfig.RateLimitEnable = false sysConfig.RateLimitEnable = false
sysConfig.RateLimitAverage = 1 sysConfig.RateLimitAverage = 1
sysConfig.RateLimitBurst = 10 sysConfig.RateLimitSoft = 10
sysConfig.RateLimitHard = 50
userConfig.GeminiExt = "gmi" userConfig.GeminiExt = "gmi"
userConfig.DefaultLang = "" userConfig.DefaultLang = ""

View File

@ -53,9 +53,13 @@ func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig,
if sysConfig.RateLimitEnable { if sysConfig.RateLimitEnable {
noPort := logEntry.RemoteAddr.String() noPort := logEntry.RemoteAddr.String()
noPort = noPort[0:strings.LastIndex(noPort, ":")] noPort = noPort[0:strings.LastIndex(noPort, ":")]
drips, allowed := rl.Allowed(noPort) limited := rl.hardLimited(noPort)
if !allowed { if limited {
conn.Write([]byte("44 " + strconv.Itoa(drips) + " second cool down, please!\r\n")) conn.Close()
}
delay, limited := rl.softLimited(noPort)
if limited {
conn.Write([]byte("44 " + strconv.Itoa(delay) + " second cool down, please!\r\n"))
logEntry.Status = 44 logEntry.Status = 44
return return
} }

View File

@ -140,7 +140,9 @@ func launch(sysConfig SysConfig, userConfig UserConfig, privInfo userInfo) int {
go func() { go func() {
for { for {
entry := <-accessLogEntries entry := <-accessLogEntries
writeLogEntry(accessLogFile, entry) if entry.Status != 0 {
writeLogEntry(accessLogFile, entry)
}
} }
}() }()
} }
@ -159,7 +161,7 @@ func launch(sysConfig SysConfig, userConfig UserConfig, privInfo userInfo) int {
// Infinite serve loop (SIGTERM breaks out) // Infinite serve loop (SIGTERM breaks out)
running := true running := true
var wg sync.WaitGroup var wg sync.WaitGroup
rl := newRateLimiter(sysConfig.RateLimitAverage, sysConfig.RateLimitBurst) rl := newRateLimiter(sysConfig.RateLimitAverage, sysConfig.RateLimitSoft, sysConfig.RateLimitHard)
for running { for running {
conn, err := listener.Accept() conn, err := listener.Accept()
if err == nil { if err == nil {

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"log"
"sync" "sync"
"time" "time"
) )
@ -8,19 +9,25 @@ import (
type RateLimiter struct { type RateLimiter struct {
mu sync.Mutex mu sync.Mutex
bucket map[string]int bucket map[string]int
bans map[string]time.Time
rate int rate int
burst int softLimit int
hardLimit int
} }
func newRateLimiter(rate int, burst int) RateLimiter { func newRateLimiter(rate int, softLimit int, hardLimit int) RateLimiter {
var rl = new(RateLimiter) var rl = new(RateLimiter)
rl.bucket = make(map[string]int) rl.bucket = make(map[string]int)
rl.bans = make(map[string]time.Time)
rl.rate = rate rl.rate = rate
rl.burst = burst rl.softLimit = softLimit
rl.hardLimit = hardLimit
// Leak periodically // Leak periodically
go func () { go func () {
for(true) { for(true) {
rl.mu.Lock() rl.mu.Lock()
// Leak the buckets
for addr, drips := range rl.bucket { for addr, drips := range rl.bucket {
if drips <= rate { if drips <= rate {
delete(rl.bucket, addr) delete(rl.bucket, addr)
@ -28,6 +35,15 @@ func newRateLimiter(rate int, burst int) RateLimiter {
rl.bucket[addr] = drips - rl.rate rl.bucket[addr] = drips - rl.rate
} }
} }
// Expire bans
now := time.Now()
for addr, expiry := range rl.bans {
if now.After(expiry) {
delete(rl.bans, addr)
}
}
// Wait
rl.mu.Unlock() rl.mu.Unlock()
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -35,16 +51,26 @@ func newRateLimiter(rate int, burst int) RateLimiter {
return *rl return *rl
} }
func (rl *RateLimiter) Allowed(addr string) (int, bool) { func (rl *RateLimiter) softLimited(addr string) (int, bool) {
rl.mu.Lock() rl.mu.Lock()
defer rl.mu.Unlock() defer rl.mu.Unlock()
drips, present := rl.bucket[addr] drips, present := rl.bucket[addr]
if !present { if !present {
rl.bucket[addr] = 1 rl.bucket[addr] = 1
return 1, true return 1, false
} }
drips += 1 drips += 1
rl.bucket[addr] = drips rl.bucket[addr] = drips
return drips, drips < rl.burst if drips > rl.hardLimit {
now := time.Now()
expiry := now.Add(time.Hour)
rl.bans[addr] = expiry
log.Println("Banning " + addr + "for 1 hour due to ignoring rate limiting.")
}
return drips, drips > rl.softLimit
} }
func (rl *RateLimiter) hardLimited(addr string) bool {
_, present := rl.bans[addr]
return present
}