diff --git a/config.go b/config.go index 4f28699..6d0c91f 100644 --- a/config.go +++ b/config.go @@ -22,6 +22,10 @@ type SysConfig struct { SCGIPaths map[string]string ReadMollyFiles bool AllowTLS12 bool + RateLimitEnable bool + RateLimitAverage int + RateLimitSoft int + RateLimitHard int } type UserConfig struct { @@ -56,6 +60,10 @@ func getConfig(filename string) (SysConfig, UserConfig, error) { sysConfig.SCGIPaths = make(map[string]string) sysConfig.ReadMollyFiles = false sysConfig.AllowTLS12 = true + sysConfig.RateLimitEnable = false + sysConfig.RateLimitAverage = 1 + sysConfig.RateLimitSoft = 10 + sysConfig.RateLimitHard = 50 userConfig.GeminiExt = "gmi" userConfig.DefaultLang = "" diff --git a/handler.go b/handler.go index 08f1499..ee961f0 100644 --- a/handler.go +++ b/handler.go @@ -36,7 +36,7 @@ func isSubdir(subdir, superdir string) (bool, error) { return false, nil } -func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig, accessLogEntries chan LogEntry, wg *sync.WaitGroup) { +func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig, accessLogEntries chan LogEntry, rl *RateLimiter, wg *sync.WaitGroup) { defer conn.Close() defer wg.Done() var tlsConn (*tls.Conn) = conn.(*tls.Conn) @@ -49,6 +49,23 @@ func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig, defer func() { accessLogEntries <- logEntry }() } + // Enforce rate limiting + if sysConfig.RateLimitEnable { + noPort := logEntry.RemoteAddr.String() + noPort = noPort[0:strings.LastIndex(noPort, ":")] + limited := rl.hardLimited(noPort) + if limited { + conn.Close() + return + } + delay, limited := rl.softLimited(noPort) + if limited { + conn.Write([]byte("44 " + strconv.Itoa(delay) + " second cool down, please!\r\n")) + logEntry.Status = 44 + return + } + } + // Read request URL, err := readRequest(conn, &logEntry) if err != nil { diff --git a/launch.go b/launch.go index 5c14d52..3738f4b 100644 --- a/launch.go +++ b/launch.go @@ -140,7 +140,9 @@ func launch(sysConfig SysConfig, userConfig UserConfig, privInfo userInfo) int { go func() { for { entry := <-accessLogEntries - writeLogEntry(accessLogFile, entry) + if entry.Status != 0 { + writeLogEntry(accessLogFile, entry) + } } }() } @@ -159,11 +161,12 @@ func launch(sysConfig SysConfig, userConfig UserConfig, privInfo userInfo) int { // Infinite serve loop (SIGTERM breaks out) running := true var wg sync.WaitGroup + rl := newRateLimiter(sysConfig.RateLimitAverage, sysConfig.RateLimitSoft, sysConfig.RateLimitHard) for running { conn, err := listener.Accept() if err == nil { wg.Add(1) - go handleGeminiRequest(conn, sysConfig, userConfig, accessLogEntries, &wg) + go handleGeminiRequest(conn, sysConfig, userConfig, accessLogEntries, &rl, &wg) } else { select { case <-shutdown: diff --git a/ratelim.go b/ratelim.go new file mode 100644 index 0000000..92a8ccd --- /dev/null +++ b/ratelim.go @@ -0,0 +1,87 @@ +package main + +import ( + "log" + "sync" + "strconv" + "time" +) + +type RateLimiter struct { + mu sync.Mutex + bucket map[string]int + bans map[string]time.Time + banCounts map[string]int + rate int + softLimit int + hardLimit int +} + +func newRateLimiter(rate int, softLimit int, hardLimit int) RateLimiter { + var rl = new(RateLimiter) + rl.bucket = make(map[string]int) + rl.bans = make(map[string]time.Time) + rl.banCounts = make(map[string]int) + rl.rate = rate + rl.softLimit = softLimit + rl.hardLimit = hardLimit + + // Leak periodically + go func () { + for(true) { + rl.mu.Lock() + // Leak the buckets + for addr, drips := range rl.bucket { + if drips <= rate { + delete(rl.bucket, addr) + } else { + rl.bucket[addr] = drips - rl.rate + } + } + // Expire bans + now := time.Now() + for addr, expiry := range rl.bans { + if now.After(expiry) { + delete(rl.bans, addr) + } + } + + // Wait + rl.mu.Unlock() + time.Sleep(time.Second) + } + }() + return *rl +} + +func (rl *RateLimiter) softLimited(addr string) (int, bool) { + rl.mu.Lock() + defer rl.mu.Unlock() + drips, present := rl.bucket[addr] + if !present { + rl.bucket[addr] = 1 + return 1, false + } + drips += 1 + rl.bucket[addr] = drips + if drips > rl.hardLimit { + banCount, present := rl.banCounts[addr] + if present { + banCount += 1 + } else { + banCount = 1 + } + rl.banCounts[addr] = banCount + banDuration := 1 << (banCount - 1) + now := time.Now() + expiry := now.Add(time.Duration(banDuration)*time.Hour) + rl.bans[addr] = expiry + log.Println("Banning " + addr + " for " + strconv.Itoa(banDuration) + " hours due to ignoring rate limiting.") + } + return drips, drips > rl.softLimit +} + +func (rl *RateLimiter) hardLimited(addr string) bool { + _, present := rl.bans[addr] + return present +}