From b796f72abceee0227fd93b64dae630b22c90c090 Mon Sep 17 00:00:00 2001 From: surtur Date: Tue, 2 Apr 2024 23:32:51 +0200 Subject: [PATCH] go: add a 404 handler --- main.go | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/main.go b/main.go index b540515..627e81a 100644 --- a/main.go +++ b/main.go @@ -2,9 +2,12 @@ package main import ( "embed" + "fmt" "io/fs" "log" "net/http" + "os" + "path" "time" ) @@ -13,19 +16,65 @@ var version = "development" //go:embed public/* var embeddedPublic embed.FS +// bytes of the 404.html page. +var b404 []byte + +// usrIP is a good way to read the src IP this way because we trust our proxy. +// largely from: https://stackoverflow.com/a/55738279 +func usrIP(r *http.Request) string { + ip := r.Header.Get("x-forwarded-for") + if ip == "" { + ip = r.Header.Get("x-real-ip") + } + + if ip == "" { + ip = r.RemoteAddr + } + + return ip +} + +// notFound writes back 404 and the 404 page. +func notFound(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write(b404) +} + +// handleNotFound allows us to override the response on e.g. 404. +// inspired by https://stackoverflow.com/a/62747667 +func handleNotFound(fs http.FileSystem) http.Handler { + fileServer := http.FileServer(fs) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p := r.URL.Path + // so as not to allow path traversals. + cleanedPath := path.Clean(p) + _, err := fs.Open(cleanedPath) + if os.IsNotExist(err) { + log.Printf("Error 404 Not Found when serving path: %s, cleaned path: %s, IP: %s", + p, cleanedPath, usrIP(r)) + notFound(w, r) + return + } + fileServer.ServeHTTP(w, r) + }) +} + func main() { // TODO: ENV WHATPORT - // TODO: handler for / + 404 + // TODO: add /ip endpoint that returns the src IP. + f404, err := embeddedPublic.ReadFile("public/404.html") + if err != nil { + log.Fatalf("no 404.html in the folder, weird: %s", fmt.Errorf("err: %w", err)) + } + + b404 = f404 + root, err := fs.Sub(embeddedPublic, "public") if err != nil { log.Fatal(err) } - fs := http.FileServer(http.FS(root)) - - http.Handle("/", fs) - - log.Printf("app built from revision '%s'\n", version) + log.Printf("Starting app built from revision '%s'\n", version) log.Print("Listening on :1314...") // https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/ @@ -33,7 +82,7 @@ func main() { ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, Addr: ":1314", - Handler: nil, + Handler: handleNotFound(http.FS(root)), } err = srv.ListenAndServe()