diff --git a/server.go b/server.go index 7fa7c08..277442d 100644 --- a/server.go +++ b/server.go @@ -16,6 +16,21 @@ import ( "github.com/pires/go-proxyproto/tlvparse" ) +type acmeCache struct { + config *certmagic.Config + cache *certmagic.Cache +} + +func newACMECache() *acmeCache { + cache := &acmeCache{} + cache.cache = certmagic.NewCache(certmagic.CacheOptions{ + GetConfigForCert: func(certmagic.Certificate) (*certmagic.Config, error) { + return cache.config, nil + }, + }) + return cache +} + type Server struct { Listeners map[string]*Listener // indexed by listening address Frontends []*Frontend @@ -26,23 +41,23 @@ type Server struct { ACMEManager *certmagic.ACMEManager ACMEConfig *certmagic.Config + acmeCache *acmeCache cancelACME context.CancelFunc } func NewServer() *Server { - cfg := certmagic.NewDefault() + // Make a copy of the defaults + acmeConfig := certmagic.Default + acmeManager := certmagic.DefaultACME - mgr := certmagic.NewACMEManager(cfg, certmagic.DefaultACME) - mgr.Agreed = true + acmeManager.Agreed = true // We're a TLS server, we don't speak HTTP - mgr.DisableHTTPChallenge = true - cfg.Issuer = mgr - cfg.Revoker = mgr + acmeManager.DisableHTTPChallenge = true return &Server{ Listeners: make(map[string]*Listener), - ACMEManager: mgr, - ACMEConfig: cfg, + ACMEManager: &acmeManager, + ACMEConfig: &acmeConfig, } } @@ -64,6 +79,14 @@ func (srv *Server) startACME() error { var ctx context.Context ctx, srv.cancelACME = context.WithCancel(context.Background()) + srv.ACMEConfig = certmagic.New(srv.acmeCache.cache, *srv.ACMEConfig) + srv.ACMEManager = certmagic.NewACMEManager(srv.ACMEConfig, *srv.ACMEManager) + + srv.ACMEConfig.Issuer = srv.ACMEManager + srv.ACMEConfig.Revoker = srv.ACMEManager + + srv.acmeCache.config = srv.ACMEConfig + for _, cert := range srv.UnmanagedCerts { if err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(cert, nil); err != nil { return err @@ -78,6 +101,8 @@ func (srv *Server) startACME() error { } func (srv *Server) Start() error { + srv.acmeCache = newACMECache() + if err := srv.startACME(); err != nil { return err } @@ -115,6 +140,9 @@ func (srv *Server) Replace(old *Server) error { } } + // Steal the old server's ACME cache + srv.acmeCache = old.acmeCache + // Restart ACME old.cancelACME() if err := srv.startACME(); err != nil {