mirror of
https://git.sr.ht/~emersion/tlstunnel
synced 2024-11-19 15:53:50 +01:00
Allow half-closed connections
This commit is contained in:
parent
8ce6fc38f2
commit
6fef9699d3
78
server.go
78
server.go
@ -241,6 +241,29 @@ func (ln *Listener) serve() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type duplexCloser interface {
|
||||||
|
CloseRead() error
|
||||||
|
CloseWrite() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type tlsCloser struct {
|
||||||
|
closer duplexCloser
|
||||||
|
tls *tls.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tlsCloser) CloseWrite() error {
|
||||||
|
tlsErr := tc.tls.CloseWrite()
|
||||||
|
err := tc.closer.CloseWrite()
|
||||||
|
if tlsErr != nil {
|
||||||
|
return tlsErr
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tlsCloser) CloseRead() error {
|
||||||
|
return tc.closer.CloseRead()
|
||||||
|
}
|
||||||
|
|
||||||
func (ln *Listener) handle(conn net.Conn) error {
|
func (ln *Listener) handle(conn net.Conn) error {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
srv := ln.atomic.Load().(*listenerHandles).Server
|
srv := ln.atomic.Load().(*listenerHandles).Server
|
||||||
@ -280,7 +303,11 @@ func (ln *Listener) handle(conn net.Conn) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fe.handle(tlsConn, &tlsState)
|
closer := tlsCloser{
|
||||||
|
closer: conn.(duplexCloser),
|
||||||
|
tls: tlsConn,
|
||||||
|
}
|
||||||
|
return fe.handle(tlsConn, closer, &tlsState)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
|
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
|
||||||
@ -311,7 +338,7 @@ type Frontend struct {
|
|||||||
Protocols []string
|
Protocols []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) error {
|
func (fe *Frontend) handle(downstream net.Conn, downstreamCloser duplexCloser, tlsState *tls.ConnectionState) error {
|
||||||
defer downstream.Close()
|
defer downstream.Close()
|
||||||
|
|
||||||
be := &fe.Backend
|
be := &fe.Backend
|
||||||
@ -319,8 +346,14 @@ func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) e
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to dial backend: %v", err)
|
return fmt.Errorf("failed to dial backend: %v", err)
|
||||||
}
|
}
|
||||||
|
upstreamCloser := upstream.(duplexCloser)
|
||||||
if be.TLSConfig != nil {
|
if be.TLSConfig != nil {
|
||||||
upstream = tls.Client(upstream, be.TLSConfig)
|
upstreamTLS := tls.Client(upstream, be.TLSConfig)
|
||||||
|
upstream = upstreamTLS
|
||||||
|
upstreamCloser = tlsCloser{
|
||||||
|
closer: upstreamCloser,
|
||||||
|
tls: upstreamTLS,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
defer upstream.Close()
|
defer upstream.Close()
|
||||||
|
|
||||||
@ -348,10 +381,30 @@ func (fe *Frontend) handle(downstream net.Conn, tlsState *tls.ConnectionState) e
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := duplexCopy(upstream, downstream); err != nil {
|
done := make(chan error, 2)
|
||||||
return fmt.Errorf("failed to copy bytes: %v", err)
|
go func() {
|
||||||
|
_, copyErr := io.Copy(upstream, downstream)
|
||||||
|
upstreamCloser.CloseWrite()
|
||||||
|
if copyErr != nil {
|
||||||
|
done <- fmt.Errorf("failed to copy from downstream to upstream: %v", copyErr)
|
||||||
|
} else {
|
||||||
|
done <- nil
|
||||||
}
|
}
|
||||||
return nil
|
}()
|
||||||
|
go func() {
|
||||||
|
_, copyErr := io.Copy(downstream, upstream)
|
||||||
|
downstreamCloser.CloseWrite()
|
||||||
|
if copyErr != nil {
|
||||||
|
done <- fmt.Errorf("failed to copy from upstream to downstream: %v", copyErr)
|
||||||
|
} else {
|
||||||
|
done <- nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := <-done; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return <-done
|
||||||
}
|
}
|
||||||
|
|
||||||
type Backend struct {
|
type Backend struct {
|
||||||
@ -361,19 +414,6 @@ type Backend struct {
|
|||||||
TLSConfig *tls.Config // nil if no TLS
|
TLSConfig *tls.Config // nil if no TLS
|
||||||
}
|
}
|
||||||
|
|
||||||
func duplexCopy(a, b io.ReadWriter) error {
|
|
||||||
done := make(chan error, 2)
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(a, b)
|
|
||||||
done <- err
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(b, a)
|
|
||||||
done <- err
|
|
||||||
}()
|
|
||||||
return <-done
|
|
||||||
}
|
|
||||||
|
|
||||||
func authorityTLV(name string) proxyproto.TLV {
|
func authorityTLV(name string) proxyproto.TLV {
|
||||||
return proxyproto.TLV{
|
return proxyproto.TLV{
|
||||||
Type: proxyproto.PP2_TYPE_AUTHORITY,
|
Type: proxyproto.PP2_TYPE_AUTHORITY,
|
||||||
|
Loading…
Reference in New Issue
Block a user