1
1
mirror of https://git.sr.ht/~emersion/tlstunnel synced 2024-11-19 15:53:50 +01:00

Allow half-closed connections

This commit is contained in:
Simon Ser 2021-02-18 17:05:53 +01:00
parent 8ce6fc38f2
commit 6fef9699d3

@ -241,6 +241,29 @@ func (ln *Listener) serve() error {
} }
} }
type duplexCloser interface {
CloseRead() error
CloseWrite() error
}
type tlsCloser struct {
closer duplexCloser
tls *tls.Conn
}
func (tc tlsCloser) CloseWrite() error {
tlsErr := tc.tls.CloseWrite()
err := tc.closer.CloseWrite()
if tlsErr != nil {
return tlsErr
}
return err
}
func (tc tlsCloser) CloseRead() error {
return tc.closer.CloseRead()
}
func (ln *Listener) handle(conn net.Conn) error { func (ln *Listener) handle(conn net.Conn) error {
defer conn.Close() defer conn.Close()
srv := ln.atomic.Load().(*listenerHandles).Server srv := ln.atomic.Load().(*listenerHandles).Server
@ -280,7 +303,11 @@ func (ln *Listener) handle(conn net.Conn) error {
return err return err
} }
return fe.handle(tlsConn, &tlsState) closer := tlsCloser{
closer: conn.(duplexCloser),
tls: tlsConn,
}
return fe.handle(tlsConn, closer, &tlsState)
} }
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) { func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
@ -311,7 +338,7 @@ type Frontend struct {
Protocols []string Protocols []string
} }
func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) error { func (fe *Frontend) handle(downstream net.Conn, downstreamCloser duplexCloser, tlsState *tls.ConnectionState) error {
defer downstream.Close() defer downstream.Close()
be := &fe.Backend be := &fe.Backend
@ -319,8 +346,14 @@ func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) e
if err != nil { if err != nil {
return fmt.Errorf("failed to dial backend: %v", err) return fmt.Errorf("failed to dial backend: %v", err)
} }
upstreamCloser := upstream.(duplexCloser)
if be.TLSConfig != nil { if be.TLSConfig != nil {
upstream = tls.Client(upstream, be.TLSConfig) upstreamTLS := tls.Client(upstream, be.TLSConfig)
upstream = upstreamTLS
upstreamCloser = tlsCloser{
closer: upstreamCloser,
tls: upstreamTLS,
}
} }
defer upstream.Close() defer upstream.Close()
@ -348,10 +381,30 @@ func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) e
} }
} }
if err := duplexCopy(upstream, downstream); err != nil { done := make(chan error, 2)
return fmt.Errorf("failed to copy bytes: %v", err) go func() {
_, copyErr := io.Copy(upstream, downstream)
upstreamCloser.CloseWrite()
if copyErr != nil {
done <- fmt.Errorf("failed to copy from downstream to upstream: %v", copyErr)
} else {
done <- nil
} }
return nil }()
go func() {
_, copyErr := io.Copy(downstream, upstream)
downstreamCloser.CloseWrite()
if copyErr != nil {
done <- fmt.Errorf("failed to copy from upstream to downstream: %v", copyErr)
} else {
done <- nil
}
}()
if err := <-done; err != nil {
return err
}
return <-done
} }
type Backend struct { type Backend struct {
@ -361,19 +414,6 @@ type Backend struct {
TLSConfig *tls.Config // nil if no TLS TLSConfig *tls.Config // nil if no TLS
} }
func duplexCopy(a, b io.ReadWriter) error {
done := make(chan error, 2)
go func() {
_, err := io.Copy(a, b)
done <- err
}()
go func() {
_, err := io.Copy(b, a)
done <- err
}()
return <-done
}
func authorityTLV(name string) proxyproto.TLV { func authorityTLV(name string) proxyproto.TLV {
return proxyproto.TLV{ return proxyproto.TLV{
Type: proxyproto.PP2_TYPE_AUTHORITY, Type: proxyproto.PP2_TYPE_AUTHORITY,