diff --git a/config.go b/config.go index 4f28699..8e2e749 100644 --- a/config.go +++ b/config.go @@ -22,6 +22,9 @@ type SysConfig struct { SCGIPaths map[string]string ReadMollyFiles bool AllowTLS12 bool + RateLimitEnable bool + RateLimitAverage int + RateLimitBurst int } type UserConfig struct { @@ -56,6 +59,9 @@ func getConfig(filename string) (SysConfig, UserConfig, error) { sysConfig.SCGIPaths = make(map[string]string) sysConfig.ReadMollyFiles = false sysConfig.AllowTLS12 = true + sysConfig.RateLimitEnable = false + sysConfig.RateLimitAverage = 1 + sysConfig.RateLimitBurst = 10 userConfig.GeminiExt = "gmi" userConfig.DefaultLang = "" diff --git a/handler.go b/handler.go index 26ba82a..4f4515d 100644 --- a/handler.go +++ b/handler.go @@ -50,12 +50,14 @@ func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig, } // Enforce rate limiting - noPort := logEntry.RemoteAddr.String() - noPort = noPort[0:strings.LastIndex(noPort, ":")] - if !rl.Allowed(noPort) { - conn.Write([]byte("44 10 second cool down, please!\r\n")) - logEntry.Status = 44 - return + if sysConfig.RateLimitEnable { + noPort := logEntry.RemoteAddr.String() + noPort = noPort[0:strings.LastIndex(noPort, ":")] + if !rl.Allowed(noPort) { + conn.Write([]byte("44 10 second cool down, please!\r\n")) + logEntry.Status = 44 + return + } } // Read request diff --git a/launch.go b/launch.go index 8c302b1..1dcaba3 100644 --- a/launch.go +++ b/launch.go @@ -159,7 +159,7 @@ func launch(sysConfig SysConfig, userConfig UserConfig, privInfo userInfo) int { // Infinite serve loop (SIGTERM breaks out) running := true var wg sync.WaitGroup - rl := newRateLimiter(100, 5) + rl := newRateLimiter(sysConfig.RateLimitAverage, sysConfig.RateLimitBurst) for running { conn, err := listener.Accept() if err == nil { diff --git a/ratelim.go b/ratelim.go index 6aa08dc..ff399f3 100644 --- a/ratelim.go +++ b/ratelim.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "sync" "time" ) @@ -9,19 +8,18 @@ import ( type RateLimiter struct { mu sync.Mutex bucket map[string]int - capacity int rate int + burst int } -func newRateLimiter(capacity int, rate int) RateLimiter { +func newRateLimiter(rate int, burst int) RateLimiter { var rl = new(RateLimiter) rl.bucket = make(map[string]int) - rl.capacity = capacity rl.rate = rate + rl.burst = burst // Leak periodically go func () { for(true) { - fmt.Println(rl.bucket) rl.mu.Lock() for addr, drips := range rl.bucket { if drips <= rate { @@ -45,7 +43,7 @@ func (rl *RateLimiter) Allowed(addr string) bool { rl.bucket[addr] = 1 return true } - if drips == rl.capacity { + if drips == rl.burst { return false } rl.bucket[addr] = drips + 1