diff --git a/handler.go b/handler.go index 8f339bf..9971b33 100644 --- a/handler.go +++ b/handler.go @@ -15,11 +15,13 @@ import ( "regexp" "strconv" "strings" + "sync" "time" ) -func handleGeminiRequest(conn net.Conn, config Config, accessLogEntries chan LogEntry, errorLog *log.Logger) { +func handleGeminiRequest(conn net.Conn, config Config, accessLogEntries chan LogEntry, errorLog *log.Logger, wg *sync.WaitGroup) { defer conn.Close() + defer wg.Done() var tlsConn (*tls.Conn) = conn.(*tls.Conn) var log LogEntry log.Time = time.Now() diff --git a/main.go b/main.go index fd43fb3..db197d3 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,10 @@ import ( "fmt" "log" "os" + "os/signal" "strconv" + "sync" + "syscall" ) var VERSION = "0.0.0" @@ -109,14 +112,36 @@ func main() { // Restrict access to the files specified in config enableSecurityRestrictions(config, errorLog) - // Infinite serve loop - for { + // Start listening for signals + shutdown := make(chan struct{}) + sigterm := make(chan os.Signal, 1) + signal.Notify(sigterm, syscall.SIGTERM) + go func() { + <-sigterm + errorLog.Println("Caught SIGTERM. Waiting for handlers to finish...") + close(shutdown) + listener.Close() + }() + + // Infinite serve loop (SIGTERM breaks out) + running := true + var wg sync.WaitGroup + for running { conn, err := listener.Accept() - if err != nil { - errorLog.Println("Error accepting connection: " + err.Error()) - log.Fatal(err) + if err == nil { + wg.Add(1) + go handleGeminiRequest(conn, config, accessLogEntries, errorLog, &wg) + } else { + select { + case <-shutdown: + running = false + default: + errorLog.Println("Error accepting connection: " + err.Error()) + } } - go handleGeminiRequest(conn, config, accessLogEntries, errorLog) } + // Wait for still-running handler Go routines to finish + wg.Wait() + errorLog.Println("Exiting.") }