diff --git a/server.go b/server.go index 369c3bc..a8ff77b 100644 --- a/server.go +++ b/server.go @@ -282,14 +282,17 @@ func (srv *Server) serve(ctx context.Context, l net.Listener) error { return err } tempDelay = 0 - go srv.ServeConn(ctx, rw) + go srv.serveConn(ctx, rw, false) } } -func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) bool { +func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc, external bool) bool { srv.mu.Lock() defer srv.mu.Unlock() - if srv.closed && !srv.shutdown { + // Reject the connection under the following conditions: + // - Shutdown or Close has been called and conn is external (from ServeConn) + // - Close (not Shutdown) has been called and conn is internal (from Serve) + if srv.closed && (external || !srv.shutdown) { return false } if srv.conns == nil { @@ -309,15 +312,17 @@ func (srv *Server) deleteConn(conn *net.Conn) { // It closes the connection when the response has been completed. // If the provided context expires before the response has completed, // ServeConn closes the connection and returns the context's error. -// -// Note that ServeConn can be used during a Shutdown. func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error { + return srv.serveConn(ctx, conn, true) +} + +func (srv *Server) serveConn(ctx context.Context, conn net.Conn, external bool) error { defer conn.Close() ctx, cancel := context.WithCancel(ctx) defer cancel() - if !srv.trackConn(&conn, cancel) { + if !srv.trackConn(&conn, cancel, external) { return context.Canceled } defer srv.tryCloseDone() @@ -332,7 +337,7 @@ func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error { errch := make(chan error, 1) go func() { - errch <- srv.serveConn(ctx, conn) + errch <- srv.goServeConn(ctx, conn) }() select { @@ -343,7 +348,7 @@ func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error { } } -func (srv *Server) serveConn(ctx context.Context, conn net.Conn) error { +func (srv *Server) goServeConn(ctx context.Context, conn net.Conn) error { ctx, cancel := context.WithCancel(ctx) done := ctx.Done() cw := &contextWriter{