1
1
mirror of https://git.sr.ht/~emersion/tlstunnel synced 2024-11-19 15:53:50 +01:00
tlstunnel/server.go

461 lines
10 KiB
Go
Raw Normal View History

package tlstunnel
2020-09-08 17:13:39 +02:00
import (
2020-09-09 14:08:20 +02:00
"context"
2020-09-08 18:24:16 +02:00
"crypto/tls"
2020-09-08 17:13:39 +02:00
"fmt"
"io"
2020-09-08 18:24:16 +02:00
"log"
2020-09-08 17:13:39 +02:00
"net"
"strings"
"sync/atomic"
2020-09-08 18:24:16 +02:00
"git.sr.ht/~emersion/go-scfg"
2020-09-08 18:24:16 +02:00
"github.com/caddyserver/certmagic"
2020-09-09 14:52:41 +02:00
"github.com/pires/go-proxyproto"
2020-10-09 14:45:55 +02:00
"github.com/pires/go-proxyproto/tlvparse"
2020-09-08 17:13:39 +02:00
)
// acmeCache pairs a certmagic certificate cache with the configuration that
// the cache hands out for its certificates.
type acmeCache struct {
	// config is the value returned by the cache's GetConfigForCert callback
	// (see newACMECache). It stays nil until startACME assigns it.
	config *certmagic.Config
	// cache is the certificate cache created by newACMECache.
	cache *certmagic.Cache
}
// newACMECache allocates an acmeCache and creates its certmagic cache. The
// cache's GetConfigForCert callback always answers with whatever
// configuration is stored in the acmeCache at the time of the call, so the
// configuration can be swapped later without recreating the cache.
func newACMECache() *acmeCache {
	ac := new(acmeCache)
	opts := certmagic.CacheOptions{
		GetConfigForCert: func(certmagic.Certificate) (*certmagic.Config, error) {
			// Read the field lazily: it is assigned after construction.
			return ac.config, nil
		},
	}
	ac.cache = certmagic.NewCache(opts)
	return ac
}
2020-09-08 17:13:39 +02:00
type Server struct {
2020-10-19 17:27:29 +02:00
Listeners map[string]*Listener // indexed by listening address
Frontends []*Frontend
ManagedNames []string
UnmanagedCerts []tls.Certificate
ACMEManager *certmagic.ACMEManager
ACMEConfig *certmagic.Config
acmeCache *acmeCache
cancelACME context.CancelFunc
2020-09-08 18:24:16 +02:00
}
func NewServer() *Server {
// Make a copy of the defaults
acmeConfig := certmagic.Default
acmeManager := certmagic.DefaultACME
2020-09-08 18:24:16 +02:00
acmeManager.Agreed = true
// We're a TLS server, we don't speak HTTP
acmeManager.DisableHTTPChallenge = true
2020-09-08 18:24:16 +02:00
return &Server{
2020-09-09 14:08:20 +02:00
Listeners: make(map[string]*Listener),
ACMEManager: &acmeManager,
ACMEConfig: &acmeConfig,
}
2020-09-08 17:13:39 +02:00
}
// Load populates the server from the configuration block cfg by delegating
// to parseConfig, and returns whatever error parseConfig reports.
func (srv *Server) Load(cfg scfg.Block) error {
	err := parseConfig(srv, cfg)
	return err
}
// RegisterListener returns the listener registered for addr, creating and
// recording a new one the first time an address is seen. The address string
// is used verbatim as the map key.
func (srv *Server) RegisterListener(addr string) *Listener {
	// TODO: normalize addr with net.LookupPort
	if existing, found := srv.Listeners[addr]; found {
		return existing
	}

	created := newListener(srv, addr)
	srv.Listeners[addr] = created
	return created
}
func (srv *Server) startACME() error {
var ctx context.Context
ctx, srv.cancelACME = context.WithCancel(context.Background())
srv.ACMEConfig = certmagic.New(srv.acmeCache.cache, *srv.ACMEConfig)
srv.ACMEManager = certmagic.NewACMEManager(srv.ACMEConfig, *srv.ACMEManager)
srv.ACMEConfig.Issuer = srv.ACMEManager
srv.ACMEConfig.Revoker = srv.ACMEManager
srv.acmeCache.config = srv.ACMEConfig
2020-10-19 17:27:29 +02:00
for _, cert := range srv.UnmanagedCerts {
if err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
2021-02-18 16:02:45 +01:00
return fmt.Errorf("failed to cache unmanaged TLS certificate: %v", err)
2020-10-19 17:27:29 +02:00
}
}
if err := srv.ACMEConfig.ManageAsync(ctx, srv.ManagedNames); err != nil {
2020-09-09 14:08:20 +02:00
return fmt.Errorf("failed to manage TLS certificates: %v", err)
}
return nil
}
func (srv *Server) Start() error {
srv.acmeCache = newACMECache()
if err := srv.startACME(); err != nil {
return err
}
for _, ln := range srv.Listeners {
if err := ln.Start(); err != nil {
2021-02-18 16:02:45 +01:00
return fmt.Errorf("failed to start listener: %v", err)
}
}
return nil
}
func (srv *Server) Stop() {
srv.cancelACME()
// TODO: clean cached unmanaged certs
for _, ln := range srv.Listeners {
ln.Stop()
}
2021-02-17 18:45:14 +01:00
srv.acmeCache.cache.Stop()
}
// Replace starts the server but takes over existing listeners from an old
// Server instance. The old instance keeps running unchanged if Replace
// returns an error.
func (srv *Server) Replace(old *Server) error {
// Try to start new listeners
for addr, ln := range srv.Listeners {
if _, ok := old.Listeners[addr]; ok {
continue
}
if err := ln.Start(); err != nil {
for _, ln2 := range srv.Listeners {
ln2.Stop()
}
2021-02-18 16:02:45 +01:00
return fmt.Errorf("failed to start listener: %v", err)
}
}
// Steal the old server's ACME cache
srv.acmeCache = old.acmeCache
// Restart ACME
old.cancelACME()
if err := srv.startACME(); err != nil {
2021-02-18 16:02:45 +01:00
for _, ln := range srv.Listeners {
ln.Stop()
}
2021-02-18 16:02:45 +01:00
return fmt.Errorf("failed to start ACME: %v", err)
}
// TODO: clean cached unmanaged certs
// Take over existing listeners and terminate old ones
for addr, oldLn := range old.Listeners {
if ln, ok := srv.Listeners[addr]; ok {
srv.Listeners[addr] = oldLn.UpdateFrom(ln)
} else {
oldLn.Stop()
}
}
return nil
}
type listenerHandles struct {
Server *Server
Frontends map[string]*Frontend // indexed by server name
2020-09-08 17:13:39 +02:00
}
// Listener accepts connections on one address and dispatches them to the
// frontends registered for it.
type Listener struct {
	// Address is the address passed to net.Listen by Start.
	Address string
	// netLn is the listening socket; it is assigned by Start.
	netLn net.Listener
	// atomic holds a *listenerHandles.
	atomic atomic.Value
}
// newListener builds a Listener for addr that belongs to srv and has no
// frontends yet. The listener is not started.
func newListener(srv *Server, addr string) *Listener {
	handles := &listenerHandles{
		Server:    srv,
		Frontends: map[string]*Frontend{},
	}

	ln := &Listener{Address: addr}
	ln.atomic.Store(handles)
	return ln
}
// RegisterFrontend attaches fe to the listener under the server name name.
// It returns an error, and leaves the existing entry in place, if a frontend
// is already registered for that name. The map inside the listener's current
// handles is modified in place.
func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error {
	handles := ln.atomic.Load().(*listenerHandles)
	if _, dup := handles.Frontends[name]; dup {
		return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name)
	}

	handles.Frontends[name] = fe
	return nil
}
// Start opens a TCP listening socket on ln.Address and serves it on a
// background goroutine. Only the error from opening the socket is returned;
// if the accept loop later fails, the process is terminated via log.Fatalf.
func (ln *Listener) Start() error {
	netLn, err := net.Listen("tcp", ln.Address)
	ln.netLn = netLn
	if err != nil {
		return err
	}
	log.Printf("listening on %q", ln.Address)

	go func() {
		serveErr := ln.serve()
		if serveErr != nil {
			log.Fatalf("listener %q: %v", ln.Address, serveErr)
		}
	}()

	return nil
}
// Stop closes the listening socket, which makes the accept loop in serve
// return. It is a no-op for a listener that has no socket, i.e. one that was
// never started or whose Start failed.
func (ln *Listener) Stop() {
	if ln.netLn == nil {
		return
	}
	ln.netLn.Close()
}
// UpdateFrom makes ln adopt the handles (server and frontends) currently
// held by new, while keeping ln's own address and socket. It returns ln.
func (ln *Listener) UpdateFrom(new *Listener) *Listener {
	handles := new.atomic.Load()
	ln.atomic.Store(handles)
	return ln
}
func (ln *Listener) serve() error {
2020-09-08 17:13:39 +02:00
for {
conn, err := ln.netLn.Accept()
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
// Listening socket has been closed by Stop()
return nil
} else if err != nil {
2020-09-08 17:13:39 +02:00
return fmt.Errorf("failed to accept connection: %v", err)
}
2020-09-08 18:24:16 +02:00
go func() {
if err := ln.handle(conn); err != nil {
log.Printf("listener %q: %v", ln.Address, err)
2020-09-08 18:24:16 +02:00
}
}()
2020-09-08 17:13:39 +02:00
}
}
2021-02-18 17:05:53 +01:00
// duplexCloser is implemented by connections whose read and write
// directions can be shut down independently.
type duplexCloser interface {
	// CloseRead shuts down the reading side of the connection.
	CloseRead() error
	// CloseWrite shuts down the writing side of the connection.
	CloseWrite() error
}
// tlsCloser is a duplexCloser for a TLS connection layered on top of a
// connection that itself supports half-close.
type tlsCloser struct {
	// closer half-closes the underlying connection.
	closer duplexCloser
	// tls is the TLS connection running over that underlying connection.
	tls *tls.Conn
}
// CloseWrite shuts down the write side at both layers: first the TLS
// connection, then the underlying connection. Both calls are always made;
// the TLS error takes precedence when both fail.
func (tc tlsCloser) CloseWrite() error {
	tlsErr := tc.tls.CloseWrite()
	rawErr := tc.closer.CloseWrite()
	if tlsErr == nil {
		return rawErr
	}
	return tlsErr
}
// CloseRead shuts down the read side of the underlying connection. The TLS
// layer is not involved.
func (tc tlsCloser) CloseRead() error {
	err := tc.closer.CloseRead()
	return err
}
func (ln *Listener) handle(conn net.Conn) error {
defer conn.Close()
srv := ln.atomic.Load().(*listenerHandles).Server
// TODO: setup timeouts
tlsConfig := srv.ACMEConfig.TLSConfig()
getConfigForClient := tlsConfig.GetConfigForClient
tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
// Call previous GetConfigForClient function, if any
var tlsConfig *tls.Config
if getConfigForClient != nil {
var err error
tlsConfig, err = getConfigForClient(hello)
if err != nil {
return nil, err
}
} else {
tlsConfig = srv.ACMEConfig.TLSConfig()
}
fe, err := ln.matchFrontend(hello.ServerName)
if err != nil {
return nil, err
}
tlsConfig.NextProtos = append(tlsConfig.NextProtos, fe.Protocols...)
return tlsConfig, nil
}
tlsConn := tls.Server(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
2021-02-18 16:02:45 +01:00
return fmt.Errorf("TLS handshake failed: %v", err)
}
tlsState := tlsConn.ConnectionState()
fe, err := ln.matchFrontend(tlsState.ServerName)
if err != nil {
return err
}
2021-02-18 17:05:53 +01:00
closer := tlsCloser{
closer: conn.(duplexCloser),
tls: tlsConn,
}
return fe.handle(tlsConn, closer, &tlsState)
}
// matchFrontend picks the frontend responsible for serverName. Candidates
// are tried in this order:
//
//  1. a frontend registered under exactly serverName;
//  2. a wildcard frontend registered as "*" followed by everything from the
//     first dot of serverName onwards, provided at least two labels remain
//     after that dot;
//  3. the frontend registered under the empty name.
//
// An error is returned if none of them exists.
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
	frontends := ln.atomic.Load().(*listenerHandles).Frontends

	if fe, found := frontends[serverName]; found {
		return fe, nil
	}

	// Match wildcard certificates, allowing only a single, non-partial
	// wildcard, in the left-most label
	if dot := strings.IndexByte(serverName, '.'); dot >= 0 {
		// Don't allow wildcards with only a TLD (e.g. *.com)
		if strings.Contains(serverName[dot+1:], ".") {
			if fe, found := frontends["*"+serverName[dot:]]; found {
				return fe, nil
			}
		}
	}

	if fe, found := frontends[""]; found {
		return fe, nil
	}

	return nil, fmt.Errorf("can't find frontend for server name %q", serverName)
}
// Frontend describes where connections for a server name are forwarded.
type Frontend struct {
	// Backend is the upstream that connections are relayed to.
	Backend Backend
	// Protocols are appended to the TLS config's NextProtos (ALPN) during
	// the handshake for this frontend.
	Protocols []string
}
2021-02-18 17:05:53 +01:00
func (fe *Frontend) handle(downstream net.Conn, downstreamCloser duplexCloser, tlsState *tls.ConnectionState) error {
2020-09-08 17:13:39 +02:00
defer downstream.Close()
be := &fe.Backend
upstream, err := net.Dial(be.Network, be.Address)
if err != nil {
return fmt.Errorf("failed to dial backend: %v", err)
}
2021-02-18 17:05:53 +01:00
upstreamCloser := upstream.(duplexCloser)
if be.TLSConfig != nil {
2021-02-18 17:05:53 +01:00
upstreamTLS := tls.Client(upstream, be.TLSConfig)
upstream = upstreamTLS
upstreamCloser = tlsCloser{
closer: upstreamCloser,
tls: upstreamTLS,
}
}
2020-09-08 17:13:39 +02:00
defer upstream.Close()
2020-09-09 14:52:41 +02:00
if be.Proxy {
h := proxyproto.HeaderProxyFromAddrs(2, downstream.RemoteAddr(), downstream.LocalAddr())
var tlvs []proxyproto.TLV
if tlsState.ServerName != "" {
tlvs = append(tlvs, authorityTLV(tlsState.ServerName))
}
if tlsState.NegotiatedProtocol != "" {
tlvs = append(tlvs, alpnTLV(tlsState.NegotiatedProtocol))
}
2020-10-09 14:45:55 +02:00
if tlv, err := sslTLV(tlsState); err != nil {
return fmt.Errorf("failed to set PROXY protocol header SSL TLV: %v", err)
} else {
tlvs = append(tlvs, tlv)
}
if err := h.SetTLVs(tlvs); err != nil {
return fmt.Errorf("failed to set PROXY protocol header TLVs: %v", err)
}
2020-09-09 14:52:41 +02:00
if _, err := h.WriteTo(upstream); err != nil {
return fmt.Errorf("failed to write PROXY protocol header: %v", err)
}
}
2021-02-18 17:05:53 +01:00
done := make(chan error, 2)
go func() {
_, copyErr := io.Copy(upstream, downstream)
upstreamCloser.CloseWrite()
if copyErr != nil {
done <- fmt.Errorf("failed to copy from downstream to upstream: %v", copyErr)
} else {
done <- nil
}
}()
go func() {
_, copyErr := io.Copy(downstream, upstream)
downstreamCloser.CloseWrite()
if copyErr != nil {
done <- fmt.Errorf("failed to copy from upstream to downstream: %v", copyErr)
} else {
done <- nil
}
}()
if err := <-done; err != nil {
return err
2021-02-18 16:02:45 +01:00
}
2021-02-18 17:05:53 +01:00
return <-done
2020-09-08 17:13:39 +02:00
}
// Backend describes an upstream that connections are relayed to.
type Backend struct {
	// Network and Address are passed to net.Dial.
	Network string
	Address string
	// Proxy enables sending a PROXY protocol header to the backend before
	// any relayed data.
	Proxy bool
	// TLSConfig is used to wrap the backend connection in TLS; nil if no
	// TLS.
	TLSConfig *tls.Config
}
2020-10-29 14:21:03 +01:00
func authorityTLV(name string) proxyproto.TLV {
return proxyproto.TLV{
2020-10-29 14:21:03 +01:00
Type: proxyproto.PP2_TYPE_AUTHORITY,
Value: []byte(name),
}
}
2020-10-09 14:45:55 +02:00
func alpnTLV(proto string) proxyproto.TLV {
return proxyproto.TLV{
2020-12-08 17:03:58 +01:00
Type: proxyproto.PP2_TYPE_ALPN,
Value: []byte(proto),
}
}
2020-10-09 14:45:55 +02:00
func sslTLV(state *tls.ConnectionState) (proxyproto.TLV, error) {
pp2ssl := tlvparse.PP2SSL{
Client: tlvparse.PP2_BITFIELD_CLIENT_SSL, // all of our connections are TLS
Verify: 1, // we haven't checked the client cert
}
var version string
switch state.Version {
case tls.VersionTLS10:
version = "TLSv1.0"
case tls.VersionTLS11:
version = "TLSv1.1"
case tls.VersionTLS12:
version = "TLSv1.2"
case tls.VersionTLS13:
version = "TLSv1.3"
}
if version != "" {
2020-10-29 14:21:03 +01:00
versionTLV := proxyproto.TLV{
Type: proxyproto.PP2_SUBTYPE_SSL_VERSION,
Value: []byte(version),
}
2020-10-09 14:45:55 +02:00
pp2ssl.TLV = append(pp2ssl.TLV, versionTLV)
}
// TODO: add PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG
// TODO: check client-provided cert, if any
return pp2ssl.Marshal()
}