From 4548a7fe655704c93f11b39b0886dfddbc3ea1ce Mon Sep 17 00:00:00 2001 From: minus Date: Tue, 22 Dec 2020 12:06:14 +0100 Subject: [PATCH] Add config reloading Instead of updating the configuration, we configure a new Server instance and then migrate Listeners that still exist to it. Open client connections are left completely untouched. Closes https://todo.sr.ht/~emersion/tlstunnel/1 --- cmd/tlstunnel/main.go | 53 +++++++++++++++--- server.go | 124 +++++++++++++++++++++++++++++++++++------- tlstunnel.1.scd | 2 + 3 files changed, 152 insertions(+), 27 deletions(-) diff --git a/cmd/tlstunnel/main.go b/cmd/tlstunnel/main.go index f4ba7ef..5f04c86 100644 --- a/cmd/tlstunnel/main.go +++ b/cmd/tlstunnel/main.go @@ -2,7 +2,11 @@ package main import ( "flag" + "fmt" "log" + "os" + "os/signal" + "syscall" "git.sr.ht/~emersion/go-scfg" "git.sr.ht/~emersion/tlstunnel" @@ -15,13 +19,10 @@ var ( certDataPath = "" ) -func main() { - flag.StringVar(&configPath, "config", configPath, "path to configuration file") - flag.Parse() - +func newServer() (*tlstunnel.Server, error) { cfg, err := scfg.Load(configPath) if err != nil { - log.Fatalf("failed to load config file: %v", err) + return nil, fmt.Errorf("failed to load config file: %w", err) } srv := tlstunnel.NewServer() @@ -37,7 +38,7 @@ func main() { } logger, err := loggerCfg.Build() if err != nil { - log.Fatalf("failed to initialize zap logger: %v", err) + return nil, fmt.Errorf("failed to initialize zap logger: %w", err) } srv.ACMEConfig.Logger = logger srv.ACMEManager.Logger = logger @@ -47,12 +48,48 @@ func main() { } if err := srv.Load(cfg); err != nil { - log.Fatal(err) + return nil, err } + return srv, nil +} + +func main() { + flag.StringVar(&configPath, "config", configPath, "path to configuration file") + flag.Parse() + + srv, err := newServer() + if err != nil { + log.Fatalf("failed to create server: %v", err) + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, 
syscall.SIGHUP) + if err := srv.Start(); err != nil { log.Fatal(err) } - select {} + for sig := range sigCh { + switch sig { + // SIGINT and SIGTERM both stop the server. Go cases never fall through implicitly. + case syscall.SIGINT, syscall.SIGTERM: + srv.Stop() + return + case syscall.SIGHUP: + log.Print("caught SIGHUP, reloading config") + newSrv, err := newServer() + if err != nil { + log.Printf("reload failed: %v", err) + continue + } + err = newSrv.Replace(srv) + if err != nil { + log.Printf("reload failed: %v", err) + continue + } + srv = newSrv + log.Print("successfully reloaded config") + } + } } diff --git a/server.go b/server.go index b79d1fe..7fa7c08 100644 --- a/server.go +++ b/server.go @@ -8,6 +8,7 @@ import ( "log" "net" "strings" + "sync/atomic" "git.sr.ht/~emersion/go-scfg" "github.com/caddyserver/certmagic" @@ -24,6 +25,8 @@ type Server struct { ACMEManager *certmagic.ACMEManager ACMEConfig *certmagic.Config + + cancelACME context.CancelFunc } func NewServer() *Server { @@ -57,17 +60,28 @@ func (srv *Server) RegisterListener(addr string) *Listener { return ln } -func (srv *Server) Start() error { +func (srv *Server) startACME() error { + var ctx context.Context + ctx, srv.cancelACME = context.WithCancel(context.Background()) + for _, cert := range srv.UnmanagedCerts { if err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(cert, nil); err != nil { return err } } - if err := srv.ACMEConfig.ManageAsync(context.Background(), srv.ManagedNames); err != nil { + if err := srv.ACMEConfig.ManageAsync(ctx, srv.ManagedNames); err != nil { return fmt.Errorf("failed to manage TLS certificates: %v", err) } + return nil +} + +func (srv *Server) Start() error { + if err := srv.startACME(); err != nil { + return err + } + for _, ln := range srv.Listeners { if err := ln.Start(); err != nil { return err @@ -76,37 +90,94 @@ func (srv *Server) Start() error { return nil } -type Listener struct { - Address string +func (srv *Server) Stop() { + srv.cancelACME() + // TODO: clean cached unmanaged certs + for _, ln := range srv.Listeners { + ln.Stop() 
+ } +} + +// Replace starts the server but takes over existing listeners from an old +// Server instance. The old instance keeps running unchanged if Replace +// returns an error. +func (srv *Server) Replace(old *Server) error { + // Try to start new listeners + for addr, ln := range srv.Listeners { + if _, ok := old.Listeners[addr]; ok { + continue + } + if err := ln.Start(); err != nil { + for _, ln2 := range srv.Listeners { + ln2.Stop() + } + return err + } + } + + // Restart ACME + old.cancelACME() + if err := srv.startACME(); err != nil { + for _, ln2 := range srv.Listeners { + ln2.Stop() + } + return err + } + // TODO: clean cached unmanaged certs + + // Take over existing listeners and terminate old ones + for addr, oldLn := range old.Listeners { + if ln, ok := srv.Listeners[addr]; ok { + srv.Listeners[addr] = oldLn.UpdateFrom(ln) + } else { + oldLn.Stop() + } + } + + return nil +} + +type listenerHandles struct { Server *Server Frontends map[string]*Frontend // indexed by server name } +type Listener struct { + Address string + netLn net.Listener + atomic atomic.Value +} + func newListener(srv *Server, addr string) *Listener { - return &Listener{ - Address: addr, + ln := &Listener{ + Address: addr, + } + ln.atomic.Store(&listenerHandles{ Server: srv, Frontends: make(map[string]*Frontend), - } + }) + return ln } func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error { - if _, ok := ln.Frontends[name]; ok { + fes := ln.atomic.Load().(*listenerHandles).Frontends + if _, ok := fes[name]; ok { return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name) } - ln.Frontends[name] = fe + fes[name] = fe return nil } func (ln *Listener) Start() error { - netLn, err := net.Listen("tcp", ln.Address) + var err error + ln.netLn, err = net.Listen("tcp", ln.Address) if err != nil { return err } log.Printf("listening on %q", ln.Address) go func() { - if err := ln.serve(netLn); err != nil { + if err := ln.serve(); err != nil { 
log.Fatalf("listener %q: %v", ln.Address, err) } }() @@ -114,10 +185,22 @@ func (ln *Listener) Start() error { return nil } -func (ln *Listener) serve(netLn net.Listener) error { +func (ln *Listener) Stop() { + ln.netLn.Close() +} + +func (ln *Listener) UpdateFrom(new *Listener) *Listener { + ln.atomic.Store(new.atomic.Load()) + return ln +} + +func (ln *Listener) serve() error { for { - conn, err := netLn.Accept() - if err != nil { + conn, err := ln.netLn.Accept() + if err != nil && strings.Contains(err.Error(), "use of closed network connection") { + // Listening socket has been closed by Stop() + return nil + } else if err != nil { return fmt.Errorf("failed to accept connection: %v", err) } @@ -131,9 +214,10 @@ func (ln *Listener) serve(netLn net.Listener) error { func (ln *Listener) handle(conn net.Conn) error { defer conn.Close() + srv := ln.atomic.Load().(*listenerHandles).Server // TODO: setup timeouts - tlsConfig := ln.Server.ACMEConfig.TLSConfig() + tlsConfig := srv.ACMEConfig.TLSConfig() getConfigForClient := tlsConfig.GetConfigForClient tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { // Call previous GetConfigForClient function, if any @@ -145,7 +229,7 @@ func (ln *Listener) handle(conn net.Conn) error { return nil, err } } else { - tlsConfig = ln.Server.ACMEConfig.TLSConfig() + tlsConfig = srv.ACMEConfig.TLSConfig() } fe, err := ln.matchFrontend(hello.ServerName) @@ -171,18 +255,20 @@ func (ln *Listener) handle(conn net.Conn) error { } func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) { - fe, ok := ln.Frontends[serverName] + fes := ln.atomic.Load().(*listenerHandles).Frontends + + fe, ok := fes[serverName] if !ok { // Match wildcard certificates, allowing only a single, non-partial // wildcard, in the left-most label i := strings.IndexByte(serverName, '.') // Don't allow wildcards with only a TLD (e.g. 
*.com) if i >= 0 && strings.IndexByte(serverName[i+1:], '.') >= 0 { - fe, ok = ln.Frontends["*"+serverName[i:]] + fe, ok = fes["*"+serverName[i:]] } } if !ok { - fe, ok = ln.Frontends[""] + fe, ok = fes[""] } if !ok { return nil, fmt.Errorf("can't find frontend for server name %q", serverName) diff --git a/tlstunnel.1.scd b/tlstunnel.1.scd index 30ee269..b4c409a 100644 --- a/tlstunnel.1.scd +++ b/tlstunnel.1.scd @@ -27,6 +27,8 @@ The config file has one directive per line. Directives have a name, followed by parameters separated by space characters. Directives may have children in blocks delimited by "{" and "}". Lines beginning with "#" are comments. +tlstunnel will reload the config file when it receives the HUP signal. + Example: ```