diff --git a/server.go b/server.go index 54043bb..ce54f88 100644 --- a/server.go +++ b/server.go @@ -241,6 +241,29 @@ func (ln *Listener) serve() error { } } +type duplexCloser interface { + CloseRead() error + CloseWrite() error +} + +type tlsCloser struct { + closer duplexCloser + tls *tls.Conn +} + +func (tc tlsCloser) CloseWrite() error { + tlsErr := tc.tls.CloseWrite() + err := tc.closer.CloseWrite() + if tlsErr != nil { + return tlsErr + } + return err +} + +func (tc tlsCloser) CloseRead() error { + return tc.closer.CloseRead() +} + func (ln *Listener) handle(conn net.Conn) error { defer conn.Close() srv := ln.atomic.Load().(*listenerHandles).Server @@ -280,7 +303,11 @@ func (ln *Listener) handle(conn net.Conn) error { return err } - return fe.handle(tlsConn, &tlsState) + closer := tlsCloser{ + closer: conn.(duplexCloser), + tls: tlsConn, + } + return fe.handle(tlsConn, closer, &tlsState) } func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) { @@ -311,7 +338,7 @@ type Frontend struct { Protocols []string } -func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) error { +func (fe *Frontend) handle(downstream net.Conn, downstreamCloser duplexCloser, tlsState *tls.ConnectionState) error { defer downstream.Close() be := &fe.Backend @@ -319,8 +346,14 @@ func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) e if err != nil { return fmt.Errorf("failed to dial backend: %v", err) } + upstreamCloser := upstream.(duplexCloser) if be.TLSConfig != nil { - upstream = tls.Client(upstream, be.TLSConfig) + upstreamTLS := tls.Client(upstream, be.TLSConfig) + upstream = upstreamTLS + upstreamCloser = tlsCloser{ + closer: upstreamCloser, + tls: upstreamTLS, + } } defer upstream.Close() @@ -348,10 +381,30 @@ func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) e } } - if err := duplexCopy(upstream, downstream); err != nil { - return fmt.Errorf("failed to copy bytes: %v", err) + done := make(chan error, 2) + go func() { + _, copyErr := io.Copy(upstream, downstream) + upstreamCloser.CloseWrite() + if copyErr != nil { + done <- fmt.Errorf("failed to copy from downstream to upstream: %v", copyErr) + } else { + done <- nil + } + }() + go func() { + _, copyErr := io.Copy(downstream, upstream) + downstreamCloser.CloseWrite() + if copyErr != nil { + done <- fmt.Errorf("failed to copy from upstream to downstream: %v", copyErr) + } else { + done <- nil + } + }() + + if err := <-done; err != nil { + return err } - return nil + return <-done } type Backend struct { @@ -361,19 +414,6 @@ type Backend struct { TLSConfig *tls.Config // nil if no TLS } -func duplexCopy(a, b io.ReadWriter) error { - done := make(chan error, 2) - go func() { - _, err := io.Copy(a, b) - done <- err - }() - go func() { - _, err := io.Copy(b, a) - done <- err - }() - return <-done -} - func authorityTLV(name string) proxyproto.TLV { return proxyproto.TLV{ Type: proxyproto.PP2_TYPE_AUTHORITY,