1
1
mirror of https://git.sr.ht/~emersion/tlstunnel synced 2024-11-19 15:53:50 +01:00

Allow to route to different backend depending on SNI

This commit is contained in:
Simon Ser 2020-09-09 13:15:03 +02:00
parent 728c5fcf17
commit 758cac1f77
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 111 additions and 27 deletions

18
main.go

@ -23,6 +23,10 @@ func main() {
} }
} }
if err := srv.Start(); err != nil {
log.Fatal(err)
}
select {} select {}
} }
@ -52,21 +56,19 @@ func parseFrontend(srv *Server, d *Directive) error {
} }
// TODO: come up with something more robust // TODO: come up with something more robust
var name string
if host != "localhost" && net.ParseIP(host) == nil { if host != "localhost" && net.ParseIP(host) == nil {
name = host
listenNames = append(listenNames, host) listenNames = append(listenNames, host)
host = "" host = ""
} }
ln, err := net.Listen("tcp", net.JoinHostPort(host, port)) addr := net.JoinHostPort(host, port)
if err != nil {
return fmt.Errorf("failed to listen on %q: %v", listenAddr, err)
}
go func() { ln := srv.RegisterListener(addr)
if err := frontend.Serve(ln); err != nil { if err := ln.RegisterFrontend(name, frontend); err != nil {
log.Fatalf("failed to serve: %v", err) return err
} }
}()
} }
if err := srv.certmagic.ManageAsync(context.Background(), listenNames); err != nil { if err := srv.certmagic.ManageAsync(context.Background(), listenNames); err != nil {

118
server.go

@ -11,6 +11,7 @@ import (
) )
type Server struct { type Server struct {
Listeners map[string]*Listener // indexed by listening address
Frontends []*Frontend Frontends []*Frontend
certmagic *certmagic.Config certmagic *certmagic.Config
} }
@ -19,6 +20,7 @@ func NewServer() *Server {
cfg := certmagic.NewDefault() cfg := certmagic.NewDefault()
acme := certmagic.DefaultACME acme := certmagic.DefaultACME
// TODO: use production CA
acme.CA = certmagic.LetsEncryptStagingCA acme.CA = certmagic.LetsEncryptStagingCA
acme.Agreed = true acme.Agreed = true
// TODO: enable HTTP challenge by peeking incoming requests on port 80 // TODO: enable HTTP challenge by peeking incoming requests on port 80
@ -27,7 +29,104 @@ func NewServer() *Server {
cfg.Issuer = mgr cfg.Issuer = mgr
cfg.Revoker = mgr cfg.Revoker = mgr
return &Server{certmagic: cfg} return &Server{
Listeners: make(map[string]*Listener),
certmagic: cfg,
}
}
func (srv *Server) RegisterListener(addr string) *Listener {
// TODO: normalize addr with net.LookupPort
ln, ok := srv.Listeners[addr]
if !ok {
ln = newListener(srv, addr)
srv.Listeners[addr] = ln
}
return ln
}
func (srv *Server) Start() error {
for _, ln := range srv.Listeners {
if err := ln.Start(); err != nil {
return err
}
}
return nil
}
type Listener struct {
Address string
Server *Server
Frontends map[string]*Frontend // indexed by server name
}
func newListener(srv *Server, addr string) *Listener {
return &Listener{
Address: addr,
Server: srv,
Frontends: make(map[string]*Frontend),
}
}
func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error {
if _, ok := ln.Frontends[name]; ok {
return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name)
}
ln.Frontends[name] = fe
return nil
}
func (ln *Listener) Start() error {
netLn, err := net.Listen("tcp", ln.Address)
if err != nil {
return err
}
log.Printf("listening on %q", ln.Address)
go func() {
if err := ln.serve(netLn); err != nil {
log.Fatalf("listener %q: %v", ln.Address, err)
}
}()
return nil
}
func (ln *Listener) serve(netLn net.Listener) error {
for {
conn, err := netLn.Accept()
if err != nil {
return fmt.Errorf("failed to accept connection: %v", err)
}
go func() {
if err := ln.handle(conn); err != nil {
log.Printf("listener %q: %v", ln.Address, err)
}
}()
}
}
func (ln *Listener) handle(conn net.Conn) error {
defer conn.Close()
// TODO: setup timeouts
tlsConn := tls.Server(conn, ln.Server.certmagic.TLSConfig())
if err := tlsConn.Handshake(); err != nil {
return err
}
tlsState := tlsConn.ConnectionState()
fe, ok := ln.Frontends[tlsState.ServerName]
if !ok {
fe, ok = ln.Frontends[""]
}
if !ok {
return fmt.Errorf("can't find frontend for server name %q", tlsState.ServerName)
}
return fe.handle(tlsConn)
} }
type Frontend struct { type Frontend struct {
@ -35,23 +134,6 @@ type Frontend struct {
Backend Backend Backend Backend
} }
func (fe *Frontend) Serve(ln net.Listener) error {
for {
conn, err := ln.Accept()
if err != nil {
return fmt.Errorf("failed to accept connection: %v", err)
}
conn = tls.Server(conn, fe.Server.certmagic.TLSConfig())
go func() {
if err := fe.handle(conn); err != nil {
log.Printf("error handling connection: %v", err)
}
}()
}
}
func (fe *Frontend) handle(downstream net.Conn) error { func (fe *Frontend) handle(downstream net.Conn) error {
defer downstream.Close() defer downstream.Close()