mirror of
https://git.sr.ht/~emersion/tlstunnel
synced 2024-11-19 15:53:50 +01:00
Allow half-closed connections
This commit is contained in:
parent
8ce6fc38f2
commit
6fef9699d3
78
server.go
78
server.go
@ -241,6 +241,29 @@ func (ln *Listener) serve() error {
|
||||
}
|
||||
}
|
||||
|
||||
type duplexCloser interface {
|
||||
CloseRead() error
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
type tlsCloser struct {
|
||||
closer duplexCloser
|
||||
tls *tls.Conn
|
||||
}
|
||||
|
||||
func (tc tlsCloser) CloseWrite() error {
|
||||
tlsErr := tc.tls.CloseWrite()
|
||||
err := tc.closer.CloseWrite()
|
||||
if tlsErr != nil {
|
||||
return tlsErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (tc tlsCloser) CloseRead() error {
|
||||
return tc.closer.CloseRead()
|
||||
}
|
||||
|
||||
func (ln *Listener) handle(conn net.Conn) error {
|
||||
defer conn.Close()
|
||||
srv := ln.atomic.Load().(*listenerHandles).Server
|
||||
@ -280,7 +303,11 @@ func (ln *Listener) handle(conn net.Conn) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return fe.handle(tlsConn, &tlsState)
|
||||
closer := tlsCloser{
|
||||
closer: conn.(duplexCloser),
|
||||
tls: tlsConn,
|
||||
}
|
||||
return fe.handle(tlsConn, closer, &tlsState)
|
||||
}
|
||||
|
||||
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
|
||||
@ -311,7 +338,7 @@ type Frontend struct {
|
||||
Protocols []string
|
||||
}
|
||||
|
||||
func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) error {
|
||||
func (fe *Frontend) handle(downstream net.Conn, downstreamCloser duplexCloser, tlsState *tls.ConnectionState) error {
|
||||
defer downstream.Close()
|
||||
|
||||
be := &fe.Backend
|
||||
@ -319,8 +346,14 @@ func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) e
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial backend: %v", err)
|
||||
}
|
||||
upstreamCloser := upstream.(duplexCloser)
|
||||
if be.TLSConfig != nil {
|
||||
upstream = tls.Client(upstream, be.TLSConfig)
|
||||
upstreamTLS := tls.Client(upstream, be.TLSConfig)
|
||||
upstream = upstreamTLS
|
||||
upstreamCloser = tlsCloser{
|
||||
closer: upstreamCloser,
|
||||
tls: upstreamTLS,
|
||||
}
|
||||
}
|
||||
defer upstream.Close()
|
||||
|
||||
@ -348,10 +381,30 @@ func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) e
|
||||
}
|
||||
}
|
||||
|
||||
if err := duplexCopy(upstream, downstream); err != nil {
|
||||
return fmt.Errorf("failed to copy bytes: %v", err)
|
||||
done := make(chan error, 2)
|
||||
go func() {
|
||||
_, copyErr := io.Copy(upstream, downstream)
|
||||
upstreamCloser.CloseWrite()
|
||||
if copyErr != nil {
|
||||
done <- fmt.Errorf("failed to copy from downstream to upstream: %v", copyErr)
|
||||
} else {
|
||||
done <- nil
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
go func() {
|
||||
_, copyErr := io.Copy(downstream, upstream)
|
||||
downstreamCloser.CloseWrite()
|
||||
if copyErr != nil {
|
||||
done <- fmt.Errorf("failed to copy from upstream to downstream: %v", copyErr)
|
||||
} else {
|
||||
done <- nil
|
||||
}
|
||||
}()
|
||||
|
||||
if err := <-done; err != nil {
|
||||
return err
|
||||
}
|
||||
return <-done
|
||||
}
|
||||
|
||||
type Backend struct {
|
||||
@ -361,19 +414,6 @@ type Backend struct {
|
||||
TLSConfig *tls.Config // nil if no TLS
|
||||
}
|
||||
|
||||
func duplexCopy(a, b io.ReadWriter) error {
|
||||
done := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := io.Copy(a, b)
|
||||
done <- err
|
||||
}()
|
||||
go func() {
|
||||
_, err := io.Copy(b, a)
|
||||
done <- err
|
||||
}()
|
||||
return <-done
|
||||
}
|
||||
|
||||
func authorityTLV(name string) proxyproto.TLV {
|
||||
return proxyproto.TLV{
|
||||
Type: proxyproto.PP2_TYPE_AUTHORITY,
|
||||
|
Loading…
Reference in New Issue
Block a user