From 137be932972188429369a26ed7ab5314d1d015c9 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 9 Sep 2020 14:08:20 +0200 Subject: [PATCH] Add `tls ca` directive --- main.go | 40 ++++++++++++++++++++++++++++++---------- server.go | 28 +++++++++++++++++----------- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/main.go b/main.go index dfb7108..cd5438e 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "log" "net" @@ -17,9 +16,18 @@ func main() { srv := NewServer() - for _, d := range cfg.ChildrenByName("frontend") { - if err := parseFrontend(srv, d); err != nil { - log.Fatalf("failed to parse frontend: %v", err) + for _, d := range cfg.Children { + var err error + switch d.Name { + case "frontend": + err = parseFrontend(srv, d) + case "tls": + err = parseTLS(srv, d) + default: + log.Fatalf("unknown %q directive", d.Name) + } + if err != nil { + log.Fatalf("directive %q: %v", d.Name, err) } } @@ -43,7 +51,6 @@ func parseFrontend(srv *Server, d *Directive) error { return err } - var certNames []string for _, listenAddr := range d.Params { host, port, err := net.SplitHostPort(listenAddr) if err != nil { @@ -54,8 +61,9 @@ func parseFrontend(srv *Server, d *Directive) error { var name string if host != "" && host != "localhost" && net.ParseIP(host) == nil { name = host - certNames = append(certNames, host) host = "" + + srv.ManagedNames = append(srv.ManagedNames, name) } addr := net.JoinHostPort(host, port) @@ -66,10 +74,6 @@ func parseFrontend(srv *Server, d *Directive) error { } } - if err := srv.certmagic.ManageAsync(context.Background(), certNames); err != nil { - return fmt.Errorf("failed to manage TLS certificates: %v", err) - } - return nil } @@ -103,3 +107,19 @@ func parseBackend(backend *Backend, d *Directive) error { return nil } + +func parseTLS(srv *Server, d *Directive) error { + for _, child := range d.Children { + switch child.Name { + case "ca": + var caURL string + if err := child.ParseParams(&caURL); err != nil { + return err + } + srv.acmeManager.CA = caURL + default: + return fmt.Errorf("unknown %q directive", child.Name) + } + } + return nil +} diff --git a/server.go b/server.go index e07ae70..91af70e 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "fmt" "io" @@ -11,27 +12,28 @@ import ( ) type Server struct { - Listeners map[string]*Listener // indexed by listening address - Frontends []*Frontend - certmagic *certmagic.Config + Listeners map[string]*Listener // indexed by listening address + Frontends []*Frontend + ManagedNames []string + + acmeManager *certmagic.ACMEManager + certmagic *certmagic.Config } func NewServer() *Server { cfg := certmagic.NewDefault() - acme := certmagic.DefaultACME - // TODO: use production CA - acme.CA = certmagic.LetsEncryptStagingCA - acme.Agreed = true + mgr := certmagic.NewACMEManager(cfg, certmagic.DefaultACME) + mgr.Agreed = true // TODO: enable HTTP challenge by peeking incoming requests on port 80 - acme.DisableHTTPChallenge = true - mgr := certmagic.NewACMEManager(cfg, acme) + mgr.DisableHTTPChallenge = true cfg.Issuer = mgr cfg.Revoker = mgr return &Server{ - Listeners: make(map[string]*Listener), - certmagic: cfg, + Listeners: make(map[string]*Listener), + acmeManager: mgr, + certmagic: cfg, } } @@ -46,6 +48,10 @@ func (srv *Server) RegisterListener(addr string) *Listener { } func (srv *Server) Start() error { + if err := srv.certmagic.ManageAsync(context.Background(), srv.ManagedNames); err != nil { + return fmt.Errorf("failed to manage TLS certificates: %v", err) + } + for _, ln := range srv.Listeners { if err := ln.Start(); err != nil { return err