diff --git a/handler.go b/handler.go index b1018c7..26ba82a 100644 --- a/handler.go +++ b/handler.go @@ -36,7 +36,7 @@ func isSubdir(subdir, superdir string) (bool, error) { return false, nil } -func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig, accessLogEntries chan LogEntry, wg *sync.WaitGroup) { +func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig, accessLogEntries chan LogEntry, rl *RateLimiter, wg *sync.WaitGroup) { defer conn.Close() defer wg.Done() var tlsConn (*tls.Conn) = conn.(*tls.Conn) @@ -49,6 +49,15 @@ func handleGeminiRequest(conn net.Conn, sysConfig SysConfig, config UserConfig, defer func() { accessLogEntries <- logEntry }() } + // Enforce rate limiting + noPort := logEntry.RemoteAddr.String() + noPort = noPort[0:strings.LastIndex(noPort, ":")] + if !rl.Allowed(noPort) { + conn.Write([]byte("44 10 second cool down, please!\r\n")) + logEntry.Status = 44 + return + } + // Read request URL, err := readRequest(conn, &logEntry) if err != nil { diff --git a/launch.go b/launch.go index 5c14d52..8c302b1 100644 --- a/launch.go +++ b/launch.go @@ -159,11 +159,12 @@ func launch(sysConfig SysConfig, userConfig UserConfig, privInfo userInfo) int { // Infinite serve loop (SIGTERM breaks out) running := true var wg sync.WaitGroup + rl := newRateLimiter(100, 5) for running { conn, err := listener.Accept() if err == nil { wg.Add(1) - go handleGeminiRequest(conn, sysConfig, userConfig, accessLogEntries, &wg) + go handleGeminiRequest(conn, sysConfig, userConfig, accessLogEntries, &rl, &wg) } else { select { case <-shutdown: diff --git a/ratelim.go b/ratelim.go new file mode 100644 index 0000000..6aa08dc --- /dev/null +++ b/ratelim.go @@ -0,0 +1,54 @@ +package main + +import ( + "fmt" + "sync" + "time" +) + +type RateLimiter struct { + mu sync.Mutex + bucket map[string]int + capacity int + rate int +} + +func newRateLimiter(capacity int, rate int) RateLimiter { + var rl = new(RateLimiter) + rl.bucket = make(map[string]int) + rl.capacity = capacity + rl.rate = rate + // Leak periodically + go func () { + for(true) { + fmt.Println(rl.bucket) + rl.mu.Lock() + for addr, drips := range rl.bucket { + if drips <= rate { + delete(rl.bucket, addr) + } else { + rl.bucket[addr] = drips - rl.rate + } + } + rl.mu.Unlock() + time.Sleep(time.Second) + } + }() + return *rl +} + +func (rl *RateLimiter) Allowed(addr string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + drips, present := rl.bucket[addr] + if !present { + rl.bucket[addr] = 1 + return true + } + if drips == rl.capacity { + return false + } + rl.bucket[addr] = drips + 1 + return true +} +