diff --git a/main.go b/main.go index 867930a..d1ea285 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,10 @@ func main() { } } + if err := srv.Start(); err != nil { + log.Fatal(err) + } + select {} } @@ -52,21 +56,19 @@ func parseFrontend(srv *Server, d *Directive) error { } // TODO: come up with something more robust + var name string if host != "localhost" && net.ParseIP(host) == nil { + name = host listenNames = append(listenNames, host) host = "" } - ln, err := net.Listen("tcp", net.JoinHostPort(host, port)) - if err != nil { - return fmt.Errorf("failed to listen on %q: %v", listenAddr, err) - } + addr := net.JoinHostPort(host, port) - go func() { - if err := frontend.Serve(ln); err != nil { - log.Fatalf("failed to serve: %v", err) - } - }() + ln := srv.RegisterListener(addr) + if err := ln.RegisterFrontend(name, frontend); err != nil { + return err + } } if err := srv.certmagic.ManageAsync(context.Background(), listenNames); err != nil { diff --git a/server.go b/server.go index d9c80ec..c930ed2 100644 --- a/server.go +++ b/server.go @@ -11,6 +11,7 @@ import ( ) type Server struct { + Listeners map[string]*Listener // indexed by listening address Frontends []*Frontend certmagic *certmagic.Config } @@ -19,6 +20,7 @@ func NewServer() *Server { cfg := certmagic.NewDefault() acme := certmagic.DefaultACME + // TODO: use production CA acme.CA = certmagic.LetsEncryptStagingCA acme.Agreed = true // TODO: enable HTTP challenge by peeking incoming requests on port 80 @@ -27,7 +29,104 @@ func NewServer() *Server { cfg.Issuer = mgr cfg.Revoker = mgr - return &Server{certmagic: cfg} + return &Server{ + Listeners: make(map[string]*Listener), + certmagic: cfg, + } +} + +func (srv *Server) RegisterListener(addr string) *Listener { + // TODO: normalize addr with net.LookupPort + ln, ok := srv.Listeners[addr] + if !ok { + ln = newListener(srv, addr) + srv.Listeners[addr] = ln + } + return ln +} + +func (srv *Server) Start() error { + for _, ln := range srv.Listeners { + if err := ln.Start(); err != nil { + return err + } + } + return nil +} + +type Listener struct { + Address string + Server *Server + Frontends map[string]*Frontend // indexed by server name +} + +func newListener(srv *Server, addr string) *Listener { + return &Listener{ + Address: addr, + Server: srv, + Frontends: make(map[string]*Frontend), + } +} + +func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error { + if _, ok := ln.Frontends[name]; ok { + return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name) + } + ln.Frontends[name] = fe + return nil +} + +func (ln *Listener) Start() error { + netLn, err := net.Listen("tcp", ln.Address) + if err != nil { + return err + } + log.Printf("listening on %q", ln.Address) + + go func() { + if err := ln.serve(netLn); err != nil { + log.Fatalf("listener %q: %v", ln.Address, err) + } + }() + + return nil +} + +func (ln *Listener) serve(netLn net.Listener) error { + for { + conn, err := netLn.Accept() + if err != nil { + return fmt.Errorf("failed to accept connection: %v", err) + } + + go func() { + if err := ln.handle(conn); err != nil { + log.Printf("listener %q: %v", ln.Address, err) + } + }() + } +} + +func (ln *Listener) handle(conn net.Conn) error { + defer conn.Close() + + // TODO: setup timeouts + tlsConn := tls.Server(conn, ln.Server.certmagic.TLSConfig()) + if err := tlsConn.Handshake(); err != nil { + return err + } + + tlsState := tlsConn.ConnectionState() + + fe, ok := ln.Frontends[tlsState.ServerName] + if !ok { + fe, ok = ln.Frontends[""] + } + if !ok { + return fmt.Errorf("can't find frontend for server name %q", tlsState.ServerName) + } + + return fe.handle(tlsConn) } type Frontend struct { @@ -35,23 +134,6 @@ type Frontend struct { Backend Backend } -func (fe *Frontend) Serve(ln net.Listener) error { - for { - conn, err := ln.Accept() - if err != nil { - return fmt.Errorf("failed to accept connection: %v", err) - } - - conn = tls.Server(conn, fe.Server.certmagic.TLSConfig()) - - go func() { - if err := fe.handle(conn); err != nil { - log.Printf("error handling connection: %v", err) - } - }() - } -} - func (fe *Frontend) handle(downstream net.Conn) error { defer downstream.Close()