2020-09-10 14:49:59 +02:00
|
|
|
package tlstunnel
|
2020-09-08 17:13:39 +02:00
|
|
|
|
|
|
|
import (
|
2020-09-09 14:08:20 +02:00
|
|
|
"context"
|
2020-09-08 18:24:16 +02:00
|
|
|
"crypto/tls"
|
2022-07-07 10:55:25 +02:00
|
|
|
"errors"
|
2020-09-08 17:13:39 +02:00
|
|
|
"fmt"
|
|
|
|
"io"
|
2020-09-08 18:24:16 +02:00
|
|
|
"log"
|
2020-09-08 17:13:39 +02:00
|
|
|
"net"
|
2020-09-12 19:43:16 +02:00
|
|
|
"strings"
|
2020-12-22 12:06:14 +01:00
|
|
|
"sync/atomic"
|
2021-02-18 17:49:52 +01:00
|
|
|
"time"
|
2020-09-08 18:24:16 +02:00
|
|
|
|
2020-10-19 16:44:46 +02:00
|
|
|
"git.sr.ht/~emersion/go-scfg"
|
2020-09-08 18:24:16 +02:00
|
|
|
"github.com/caddyserver/certmagic"
|
2020-09-09 14:52:41 +02:00
|
|
|
"github.com/pires/go-proxyproto"
|
2020-10-09 14:45:55 +02:00
|
|
|
"github.com/pires/go-proxyproto/tlvparse"
|
2020-09-08 17:13:39 +02:00
|
|
|
)
|
|
|
|
|
2021-02-18 18:16:10 +01:00
|
|
|
// tlsHandshakeTimeout is the deadline applied to a freshly accepted
// connection while its TLS handshake is in progress.
const tlsHandshakeTimeout = 20 * time.Second
|
2021-02-18 17:49:52 +01:00
|
|
|
|
2021-02-17 18:33:07 +01:00
|
|
|
type acmeCache struct {
|
2021-02-18 18:20:47 +01:00
|
|
|
config atomic.Value
|
2021-02-17 18:33:07 +01:00
|
|
|
cache *certmagic.Cache
|
|
|
|
}
|
|
|
|
|
|
|
|
func newACMECache() *acmeCache {
|
|
|
|
cache := &acmeCache{}
|
|
|
|
cache.cache = certmagic.NewCache(certmagic.CacheOptions{
|
|
|
|
GetConfigForCert: func(certmagic.Certificate) (*certmagic.Config, error) {
|
2021-02-18 18:20:47 +01:00
|
|
|
return cache.config.Load().(*certmagic.Config), nil
|
2021-02-17 18:33:07 +01:00
|
|
|
},
|
|
|
|
})
|
|
|
|
return cache
|
|
|
|
}
|
|
|
|
|
2020-09-08 17:13:39 +02:00
|
|
|
type Server struct {
|
2020-10-19 17:27:29 +02:00
|
|
|
Listeners map[string]*Listener // indexed by listening address
|
|
|
|
Frontends []*Frontend
|
2023-01-26 11:43:59 +01:00
|
|
|
Debug bool
|
2020-10-19 17:27:29 +02:00
|
|
|
|
|
|
|
ManagedNames []string
|
|
|
|
UnmanagedCerts []tls.Certificate
|
|
|
|
|
2022-07-07 10:49:10 +02:00
|
|
|
ACMEIssuer *certmagic.ACMEIssuer
|
|
|
|
ACMEConfig *certmagic.Config
|
2020-12-22 12:06:14 +01:00
|
|
|
|
2023-11-20 15:40:42 +01:00
|
|
|
acmeCache *acmeCache
|
|
|
|
cancelACME context.CancelFunc
|
|
|
|
unmanagedHashes []string
|
2020-09-08 18:24:16 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func NewServer() *Server {
|
2021-02-17 18:33:07 +01:00
|
|
|
// Make a copy of the defaults
|
|
|
|
acmeConfig := certmagic.Default
|
|
|
|
acmeManager := certmagic.DefaultACME
|
2020-09-08 18:24:16 +02:00
|
|
|
|
2021-02-17 18:33:07 +01:00
|
|
|
acmeManager.Agreed = true
|
2020-10-21 15:24:25 +02:00
|
|
|
// We're a TLS server, we don't speak HTTP
|
2021-02-17 18:33:07 +01:00
|
|
|
acmeManager.DisableHTTPChallenge = true
|
2020-09-08 18:24:16 +02:00
|
|
|
|
2020-09-09 13:15:03 +02:00
|
|
|
return &Server{
|
2022-07-07 10:49:10 +02:00
|
|
|
Listeners: make(map[string]*Listener),
|
|
|
|
ACMEIssuer: &acmeManager,
|
|
|
|
ACMEConfig: &acmeConfig,
|
2020-09-09 13:15:03 +02:00
|
|
|
}
|
2020-09-08 17:13:39 +02:00
|
|
|
}
|
|
|
|
|
2020-10-19 16:44:46 +02:00
|
|
|
func (srv *Server) Load(cfg scfg.Block) error {
|
2020-09-10 15:05:43 +02:00
|
|
|
return parseConfig(srv, cfg)
|
|
|
|
}
|
|
|
|
|
2020-09-09 13:15:03 +02:00
|
|
|
func (srv *Server) RegisterListener(addr string) *Listener {
|
|
|
|
// TODO: normalize addr with net.LookupPort
|
|
|
|
ln, ok := srv.Listeners[addr]
|
|
|
|
if !ok {
|
|
|
|
ln = newListener(srv, addr)
|
|
|
|
srv.Listeners[addr] = ln
|
|
|
|
}
|
|
|
|
return ln
|
|
|
|
}
|
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
func (srv *Server) startACME() error {
|
|
|
|
var ctx context.Context
|
|
|
|
ctx, srv.cancelACME = context.WithCancel(context.Background())
|
|
|
|
|
2021-02-17 18:33:07 +01:00
|
|
|
srv.ACMEConfig = certmagic.New(srv.acmeCache.cache, *srv.ACMEConfig)
|
2022-07-07 10:49:10 +02:00
|
|
|
srv.ACMEIssuer = certmagic.NewACMEIssuer(srv.ACMEConfig, *srv.ACMEIssuer)
|
2021-02-17 18:33:07 +01:00
|
|
|
|
2022-07-07 10:49:10 +02:00
|
|
|
srv.ACMEConfig.Issuers = []certmagic.Issuer{srv.ACMEIssuer}
|
2021-02-17 18:33:07 +01:00
|
|
|
|
2021-02-18 18:20:47 +01:00
|
|
|
srv.acmeCache.config.Store(srv.ACMEConfig)
|
2021-02-17 18:33:07 +01:00
|
|
|
|
2020-10-19 17:27:29 +02:00
|
|
|
for _, cert := range srv.UnmanagedCerts {
|
2023-11-20 15:40:42 +01:00
|
|
|
hash, err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(ctx, cert, nil)
|
2023-11-20 15:34:03 +01:00
|
|
|
if err != nil {
|
2021-02-18 16:02:45 +01:00
|
|
|
return fmt.Errorf("failed to cache unmanaged TLS certificate: %v", err)
|
2020-10-19 17:27:29 +02:00
|
|
|
}
|
2023-11-20 15:40:42 +01:00
|
|
|
srv.unmanagedHashes = append(srv.unmanagedHashes, hash)
|
2020-10-19 17:27:29 +02:00
|
|
|
}
|
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
if err := srv.ACMEConfig.ManageAsync(ctx, srv.ManagedNames); err != nil {
|
2020-09-09 14:08:20 +02:00
|
|
|
return fmt.Errorf("failed to manage TLS certificates: %v", err)
|
|
|
|
}
|
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (srv *Server) Start() error {
|
2021-02-17 18:33:07 +01:00
|
|
|
srv.acmeCache = newACMECache()
|
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
if err := srv.startACME(); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2020-09-09 13:15:03 +02:00
|
|
|
for _, ln := range srv.Listeners {
|
|
|
|
if err := ln.Start(); err != nil {
|
2021-02-18 16:02:45 +01:00
|
|
|
return fmt.Errorf("failed to start listener: %v", err)
|
2020-09-09 13:15:03 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
func (srv *Server) Stop() {
|
|
|
|
srv.cancelACME()
|
2022-02-03 10:42:06 +01:00
|
|
|
for addr, ln := range srv.Listeners {
|
|
|
|
if err := ln.Stop(); err != nil {
|
|
|
|
log.Printf("listener %q: failed to stop: %v", addr, err)
|
|
|
|
}
|
2020-12-22 12:06:14 +01:00
|
|
|
}
|
2021-02-17 18:45:14 +01:00
|
|
|
srv.acmeCache.cache.Stop()
|
2020-12-22 12:06:14 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Replace starts the server but takes over existing listeners from an old
|
|
|
|
// Server instance. The old instance keeps running unchanged if Replace
|
|
|
|
// returns an error.
|
|
|
|
func (srv *Server) Replace(old *Server) error {
|
|
|
|
// Try to start new listeners
|
|
|
|
for addr, ln := range srv.Listeners {
|
|
|
|
if _, ok := old.Listeners[addr]; ok {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if err := ln.Start(); err != nil {
|
|
|
|
for _, ln2 := range srv.Listeners {
|
|
|
|
ln2.Stop()
|
|
|
|
}
|
2021-02-18 16:02:45 +01:00
|
|
|
return fmt.Errorf("failed to start listener: %v", err)
|
2020-12-22 12:06:14 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-02-17 18:33:07 +01:00
|
|
|
// Steal the old server's ACME cache
|
|
|
|
srv.acmeCache = old.acmeCache
|
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
// Restart ACME
|
|
|
|
old.cancelACME()
|
|
|
|
if err := srv.startACME(); err != nil {
|
2021-02-18 16:02:45 +01:00
|
|
|
for _, ln := range srv.Listeners {
|
|
|
|
ln.Stop()
|
2020-12-22 12:06:14 +01:00
|
|
|
}
|
2021-02-18 16:02:45 +01:00
|
|
|
return fmt.Errorf("failed to start ACME: %v", err)
|
2020-12-22 12:06:14 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Take over existing listeners and terminate old ones
|
|
|
|
for addr, oldLn := range old.Listeners {
|
|
|
|
if ln, ok := srv.Listeners[addr]; ok {
|
|
|
|
srv.Listeners[addr] = oldLn.UpdateFrom(ln)
|
|
|
|
} else {
|
2022-02-03 10:42:06 +01:00
|
|
|
if err := oldLn.Stop(); err != nil {
|
|
|
|
log.Printf("listener %q: failed to stop: %v", addr, err)
|
|
|
|
}
|
2020-12-22 12:06:14 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-02-18 18:09:37 +01:00
|
|
|
// Cleanup managed certs which are no longer used
|
2023-11-20 15:36:04 +01:00
|
|
|
managed := make(map[string]struct{}, len(srv.ManagedNames))
|
2021-02-18 18:09:37 +01:00
|
|
|
for _, name := range srv.ManagedNames {
|
|
|
|
managed[name] = struct{}{}
|
|
|
|
}
|
2023-11-20 15:34:03 +01:00
|
|
|
removeManaged := make([]string, 0, len(old.ManagedNames))
|
2021-02-18 18:09:37 +01:00
|
|
|
for _, name := range old.ManagedNames {
|
|
|
|
if _, ok := managed[name]; !ok {
|
2023-11-20 15:34:03 +01:00
|
|
|
removeManaged = append(removeManaged, name)
|
2021-02-18 18:09:37 +01:00
|
|
|
}
|
|
|
|
}
|
2023-11-20 15:34:03 +01:00
|
|
|
srv.acmeCache.cache.RemoveManaged(removeManaged)
|
2021-02-18 18:09:37 +01:00
|
|
|
|
2023-11-20 15:40:42 +01:00
|
|
|
// Cleanup unmanaged certs which are no longer used
|
|
|
|
unmanaged := make(map[string]struct{}, len(srv.unmanagedHashes))
|
|
|
|
for _, hash := range srv.unmanagedHashes {
|
|
|
|
unmanaged[hash] = struct{}{}
|
|
|
|
}
|
|
|
|
removeUnmanaged := make([]string, 0, len(old.unmanagedHashes))
|
|
|
|
for _, hash := range old.unmanagedHashes {
|
|
|
|
if _, ok := unmanaged[hash]; !ok {
|
|
|
|
removeUnmanaged = append(removeUnmanaged, hash)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
srv.acmeCache.cache.Remove(removeUnmanaged)
|
2021-02-18 18:09:37 +01:00
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-01-27 11:04:36 +01:00
|
|
|
// clientError marks an error as a client error. The accept loop only logs
// errors carrying this marker when the server's Debug flag is set. The
// wrapped error is embedded, so its Error method is promoted.
type clientError struct{ error }
|
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
type listenerHandles struct {
|
2020-09-09 13:15:03 +02:00
|
|
|
Server *Server
|
|
|
|
Frontends map[string]*Frontend // indexed by server name
|
2020-09-08 17:13:39 +02:00
|
|
|
}
|
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
// Listener accepts TCP connections on Address and hands them to the
// frontends registered for it.
type Listener struct {
	Address string
	netLn   net.Listener // set by Start, closed by Stop
	// atomic holds the current *listenerHandles; UpdateFrom swaps it.
	atomic atomic.Value
}
|
|
|
|
|
2020-09-09 13:15:03 +02:00
|
|
|
func newListener(srv *Server, addr string) *Listener {
|
2020-12-22 12:06:14 +01:00
|
|
|
ln := &Listener{
|
|
|
|
Address: addr,
|
|
|
|
}
|
|
|
|
ln.atomic.Store(&listenerHandles{
|
2020-09-09 13:15:03 +02:00
|
|
|
Server: srv,
|
|
|
|
Frontends: make(map[string]*Frontend),
|
2020-12-22 12:06:14 +01:00
|
|
|
})
|
|
|
|
return ln
|
2020-09-09 13:15:03 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error {
|
2020-12-22 12:06:14 +01:00
|
|
|
fes := ln.atomic.Load().(*listenerHandles).Frontends
|
|
|
|
if _, ok := fes[name]; ok {
|
2020-09-09 13:15:03 +02:00
|
|
|
return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name)
|
|
|
|
}
|
2020-12-22 12:06:14 +01:00
|
|
|
fes[name] = fe
|
2020-09-09 13:15:03 +02:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ln *Listener) Start() error {
|
2020-12-22 12:06:14 +01:00
|
|
|
var err error
|
|
|
|
ln.netLn, err = net.Listen("tcp", ln.Address)
|
2020-09-09 13:15:03 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
log.Printf("listening on %q", ln.Address)
|
|
|
|
|
2022-02-03 10:36:08 +01:00
|
|
|
ln.netLn = &retryListener{Listener: ln.netLn}
|
|
|
|
|
2020-09-09 13:15:03 +02:00
|
|
|
go func() {
|
2020-12-22 12:06:14 +01:00
|
|
|
if err := ln.serve(); err != nil {
|
2020-09-09 13:15:03 +02:00
|
|
|
log.Fatalf("listener %q: %v", ln.Address, err)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2022-02-03 10:42:06 +01:00
|
|
|
func (ln *Listener) Stop() error {
|
|
|
|
return ln.netLn.Close()
|
2020-12-22 12:06:14 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (ln *Listener) UpdateFrom(new *Listener) *Listener {
|
|
|
|
ln.atomic.Store(new.atomic.Load())
|
|
|
|
return ln
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ln *Listener) serve() error {
|
2020-09-08 17:13:39 +02:00
|
|
|
for {
|
2020-12-22 12:06:14 +01:00
|
|
|
conn, err := ln.netLn.Accept()
|
2022-07-07 10:55:25 +02:00
|
|
|
if errors.Is(err, net.ErrClosed) {
|
2020-12-22 12:06:14 +01:00
|
|
|
// Listening socket has been closed by Stop()
|
|
|
|
return nil
|
|
|
|
} else if err != nil {
|
2020-09-08 17:13:39 +02:00
|
|
|
return fmt.Errorf("failed to accept connection: %v", err)
|
|
|
|
}
|
|
|
|
|
2020-09-08 18:24:16 +02:00
|
|
|
go func() {
|
2023-01-26 11:43:59 +01:00
|
|
|
err := ln.handle(conn)
|
2023-02-20 14:40:44 +01:00
|
|
|
if err == nil {
|
|
|
|
return
|
|
|
|
}
|
2023-01-26 11:43:59 +01:00
|
|
|
srv := ln.atomic.Load().(*listenerHandles).Server
|
2023-01-27 11:04:36 +01:00
|
|
|
var clientErr clientError
|
|
|
|
if !errors.As(err, &clientErr) || srv.Debug {
|
2023-01-26 11:43:59 +01:00
|
|
|
log.Printf("listener %q: connection %q: %v", ln.Address, conn.RemoteAddr(), err)
|
2020-09-08 18:24:16 +02:00
|
|
|
}
|
|
|
|
}()
|
2020-09-08 17:13:39 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-09-09 13:15:03 +02:00
|
|
|
func (ln *Listener) handle(conn net.Conn) error {
|
|
|
|
defer conn.Close()
|
2020-12-22 12:06:14 +01:00
|
|
|
srv := ln.atomic.Load().(*listenerHandles).Server
|
2020-09-09 13:15:03 +02:00
|
|
|
|
2020-12-22 12:06:14 +01:00
|
|
|
tlsConfig := srv.ACMEConfig.TLSConfig()
|
2020-10-19 10:53:36 +02:00
|
|
|
getConfigForClient := tlsConfig.GetConfigForClient
|
|
|
|
tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
|
|
// Call previous GetConfigForClient function, if any
|
|
|
|
var tlsConfig *tls.Config
|
|
|
|
if getConfigForClient != nil {
|
|
|
|
var err error
|
|
|
|
tlsConfig, err = getConfigForClient(hello)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
} else {
|
2020-12-22 12:06:14 +01:00
|
|
|
tlsConfig = srv.ACMEConfig.TLSConfig()
|
2020-10-19 10:53:36 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
fe, err := ln.matchFrontend(hello.ServerName)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2021-02-18 16:05:45 +01:00
|
|
|
tlsConfig.NextProtos = append(tlsConfig.NextProtos, fe.Protocols...)
|
2020-10-19 10:53:36 +02:00
|
|
|
return tlsConfig, nil
|
|
|
|
}
|
|
|
|
tlsConn := tls.Server(conn, tlsConfig)
|
2021-02-18 17:49:52 +01:00
|
|
|
|
|
|
|
if err := tlsConn.SetDeadline(time.Now().Add(tlsHandshakeTimeout)); err != nil {
|
|
|
|
return fmt.Errorf("failed to set TLS handshake timeout: %v", err)
|
|
|
|
}
|
2022-02-03 10:22:53 +01:00
|
|
|
if err := tlsConn.Handshake(); err == io.EOF {
|
|
|
|
return nil
|
|
|
|
} else if err != nil {
|
2023-01-27 11:04:36 +01:00
|
|
|
return clientError{fmt.Errorf("TLS handshake failed: %v", err)}
|
2020-09-09 13:15:03 +02:00
|
|
|
}
|
2021-02-18 17:49:52 +01:00
|
|
|
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
|
|
|
|
return fmt.Errorf("failed to reset TLS handshake timeout: %v", err)
|
|
|
|
}
|
|
|
|
// TODO: allow setting custom downstream timeouts
|
2020-09-09 13:15:03 +02:00
|
|
|
|
|
|
|
tlsState := tlsConn.ConnectionState()
|
2020-10-19 10:53:36 +02:00
|
|
|
fe, err := ln.matchFrontend(tlsState.ServerName)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2020-09-09 13:15:03 +02:00
|
|
|
|
2020-10-19 10:53:36 +02:00
|
|
|
return fe.handle(tlsConn, &tlsState)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
|
2020-12-22 12:06:14 +01:00
|
|
|
fes := ln.atomic.Load().(*listenerHandles).Frontends
|
|
|
|
|
|
|
|
fe, ok := fes[serverName]
|
2020-09-12 19:43:16 +02:00
|
|
|
if !ok {
|
2020-10-19 10:53:36 +02:00
|
|
|
// Match wildcard certificates, allowing only a single, non-partial
|
|
|
|
// wildcard, in the left-most label
|
|
|
|
i := strings.IndexByte(serverName, '.')
|
|
|
|
// Don't allow wildcards with only a TLD (e.g. *.com)
|
|
|
|
if i >= 0 && strings.IndexByte(serverName[i+1:], '.') >= 0 {
|
2020-12-22 12:06:14 +01:00
|
|
|
fe, ok = fes["*"+serverName[i:]]
|
2020-09-12 19:43:16 +02:00
|
|
|
}
|
|
|
|
}
|
2020-09-09 13:15:03 +02:00
|
|
|
if !ok {
|
2020-12-22 12:06:14 +01:00
|
|
|
fe, ok = fes[""]
|
2020-09-09 13:15:03 +02:00
|
|
|
}
|
|
|
|
if !ok {
|
2020-10-19 10:53:36 +02:00
|
|
|
return nil, fmt.Errorf("can't find frontend for server name %q", serverName)
|
2020-09-09 13:15:03 +02:00
|
|
|
}
|
|
|
|
|
2020-10-19 10:53:36 +02:00
|
|
|
return fe, nil
|
2020-09-09 13:15:03 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
type Frontend struct {
|
2020-10-19 10:53:36 +02:00
|
|
|
Backend Backend
|
|
|
|
Protocols []string
|
2020-09-09 13:15:03 +02:00
|
|
|
}
|
|
|
|
|
2020-10-09 12:21:19 +02:00
|
|
|
func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) error {
|
2020-09-08 17:13:39 +02:00
|
|
|
defer downstream.Close()
|
|
|
|
|
2021-02-18 17:49:52 +01:00
|
|
|
// TODO: setup upstream timeouts
|
|
|
|
|
2020-09-08 17:13:39 +02:00
|
|
|
be := &fe.Backend
|
|
|
|
upstream, err := net.Dial(be.Network, be.Address)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to dial backend: %v", err)
|
|
|
|
}
|
2020-10-31 10:34:02 +01:00
|
|
|
if be.TLSConfig != nil {
|
|
|
|
upstream = tls.Client(upstream, be.TLSConfig)
|
|
|
|
}
|
2020-09-08 17:13:39 +02:00
|
|
|
defer upstream.Close()
|
|
|
|
|
2020-09-09 14:52:41 +02:00
|
|
|
if be.Proxy {
|
2023-02-09 15:19:29 +01:00
|
|
|
h := proxyproto.HeaderProxyFromAddrs(byte(be.ProxyVersion), downstream.RemoteAddr(), downstream.LocalAddr())
|
2020-10-09 12:21:19 +02:00
|
|
|
|
|
|
|
var tlvs []proxyproto.TLV
|
|
|
|
if tlsState.ServerName != "" {
|
|
|
|
tlvs = append(tlvs, authorityTLV(tlsState.ServerName))
|
|
|
|
}
|
2020-10-19 10:53:36 +02:00
|
|
|
if tlsState.NegotiatedProtocol != "" {
|
|
|
|
tlvs = append(tlvs, alpnTLV(tlsState.NegotiatedProtocol))
|
|
|
|
}
|
2020-10-09 14:45:55 +02:00
|
|
|
if tlv, err := sslTLV(tlsState); err != nil {
|
|
|
|
return fmt.Errorf("failed to set PROXY protocol header SSL TLV: %v", err)
|
|
|
|
} else {
|
|
|
|
tlvs = append(tlvs, tlv)
|
|
|
|
}
|
2020-10-09 12:21:19 +02:00
|
|
|
if err := h.SetTLVs(tlvs); err != nil {
|
|
|
|
return fmt.Errorf("failed to set PROXY protocol header TLVs: %v", err)
|
|
|
|
}
|
|
|
|
|
2020-09-09 14:52:41 +02:00
|
|
|
if _, err := h.WriteTo(upstream); err != nil {
|
|
|
|
return fmt.Errorf("failed to write PROXY protocol header: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-02-18 16:02:45 +01:00
|
|
|
if err := duplexCopy(upstream, downstream); err != nil {
|
2023-01-27 11:04:36 +01:00
|
|
|
return clientError{fmt.Errorf("failed to copy bytes: %v", err)}
|
2021-02-18 16:02:45 +01:00
|
|
|
}
|
|
|
|
return nil
|
2020-09-08 17:13:39 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
// Backend describes the upstream endpoint a Frontend forwards to.
type Backend struct {
	Network      string      // network argument passed to net.Dial
	Address      string      // address argument passed to net.Dial
	Proxy        bool        // send a PROXY protocol header before relaying
	ProxyVersion int         // version byte used when building the PROXY header
	TLSConfig    *tls.Config // nil if no TLS
}
|
|
|
|
|
|
|
|
// duplexCopy copies from b to a and from a to b concurrently and returns the
// result of whichever direction finishes first. The other direction keeps
// running in its goroutine until its own copy ends; the result channel is
// buffered so that goroutine never blocks on reporting.
func duplexCopy(a, b io.ReadWriter) error {
	results := make(chan error, 2)

	pump := func(dst io.Writer, src io.Reader) {
		_, err := io.Copy(dst, src)
		results <- err
	}

	go pump(a, b)
	go pump(b, a)

	first := <-results
	return first
}
|
2020-10-09 12:21:19 +02:00
|
|
|
|
2020-10-29 14:21:03 +01:00
|
|
|
func authorityTLV(name string) proxyproto.TLV {
|
2020-10-09 12:21:19 +02:00
|
|
|
return proxyproto.TLV{
|
2020-10-29 14:21:03 +01:00
|
|
|
Type: proxyproto.PP2_TYPE_AUTHORITY,
|
|
|
|
Value: []byte(name),
|
2020-10-09 12:21:19 +02:00
|
|
|
}
|
|
|
|
}
|
2020-10-09 14:45:55 +02:00
|
|
|
|
2020-10-19 10:53:36 +02:00
|
|
|
func alpnTLV(proto string) proxyproto.TLV {
|
|
|
|
return proxyproto.TLV{
|
2020-12-08 17:03:58 +01:00
|
|
|
Type: proxyproto.PP2_TYPE_ALPN,
|
2020-10-19 10:53:36 +02:00
|
|
|
Value: []byte(proto),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-10-09 14:45:55 +02:00
|
|
|
func sslTLV(state *tls.ConnectionState) (proxyproto.TLV, error) {
|
|
|
|
pp2ssl := tlvparse.PP2SSL{
|
|
|
|
Client: tlvparse.PP2_BITFIELD_CLIENT_SSL, // all of our connections are TLS
|
|
|
|
Verify: 1, // we haven't checked the client cert
|
|
|
|
}
|
|
|
|
|
|
|
|
var version string
|
|
|
|
switch state.Version {
|
|
|
|
case tls.VersionTLS10:
|
|
|
|
version = "TLSv1.0"
|
|
|
|
case tls.VersionTLS11:
|
|
|
|
version = "TLSv1.1"
|
|
|
|
case tls.VersionTLS12:
|
|
|
|
version = "TLSv1.2"
|
|
|
|
case tls.VersionTLS13:
|
|
|
|
version = "TLSv1.3"
|
|
|
|
}
|
|
|
|
if version != "" {
|
2020-10-29 14:21:03 +01:00
|
|
|
versionTLV := proxyproto.TLV{
|
|
|
|
Type: proxyproto.PP2_SUBTYPE_SSL_VERSION,
|
|
|
|
Value: []byte(version),
|
|
|
|
}
|
2020-10-09 14:45:55 +02:00
|
|
|
pp2ssl.TLV = append(pp2ssl.TLV, versionTLV)
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: add PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG
|
|
|
|
// TODO: check client-provided cert, if any
|
|
|
|
|
|
|
|
return pp2ssl.Marshal()
|
|
|
|
}
|
2022-02-03 10:36:08 +01:00
|
|
|
|
|
|
|
// retryListener wraps a net.Listener and retries Accept with a growing delay
// while the underlying listener keeps reporting temporary errors.
type retryListener struct {
	net.Listener

	// delay is the pause used before the most recent retry; zero means the
	// last Accept did not end in a temporary error.
	delay time.Duration
}

// Accept calls the wrapped listener's Accept until it yields something other
// than a temporary net.Error, which is then returned as is. After each
// temporary error it logs the failure and sleeps before retrying: the first
// pause is 5ms, every further one doubles, capped at one second. The delay
// is reset once Accept returns.
func (ln *retryListener) Accept() (net.Conn, error) {
	const (
		initialDelay = 5 * time.Millisecond
		maxDelay     = 1 * time.Second
	)

	for {
		conn, err := ln.Listener.Accept()

		ne, isNetErr := err.(net.Error)
		if !isNetErr || !ne.Temporary() {
			ln.delay = 0
			return conn, err
		}

		switch {
		case ln.delay == 0:
			ln.delay = initialDelay
		default:
			ln.delay *= 2
		}
		if ln.delay > maxDelay {
			ln.delay = maxDelay
		}

		log.Printf("listener %q: accept error (retrying in %v): %v", ln.Addr(), ln.delay, err)
		time.Sleep(ln.delay)
	}
}
|