diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5db72555..388b17d0 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -7,7 +7,7 @@ import ( "io" "net/netip" "os" - "sync/atomic" + "sync" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -37,8 +37,16 @@ type TesterConn struct { RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula - closed atomic.Bool - l *logrus.Logger + // done is closed exactly once by Close. Senders select on it so they + // never race with a channel close; readers exit when it fires. The + // packet channels are intentionally never closed - that was the source + // of `send on closed channel` panics when a WriteTo/Send from another + // goroutine passed the close check and reached the send just after + // Close ran. + done chan struct{} + closeOnce sync.Once + + l *logrus.Logger } func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { @@ -46,6 +54,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), + done: make(chan struct{}), l: l, }, nil } @@ -54,10 +63,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *TesterConn) Send(packet *Packet) { - if u.closed.Load() { - return - } - h := &header.H{} if err := h.Parse(packet.Data); err != nil { panic(err) @@ -68,7 +73,10 @@ func (u *TesterConn) Send(packet *Packet) { WithField("dataLen", len(packet.Data)). Debug("UDP receiving injected packet") } - u.RxPackets <- packet + select { + case <-u.done: + case u.RxPackets <- packet: + } } // Get will pull a UdpPacket from the transmit queue @@ -76,7 +84,12 @@ func (u *TesterConn) Send(packet *Packet) { // packets were ingested from the tun side (in most cases), you can send them with Tun.Send func (u *TesterConn) Get(block bool) *Packet { if block { - return <-u.TxPackets + select { + case <-u.done: + return nil + case p := <-u.TxPackets: + return p + } } select { @@ -92,10 +105,6 @@ func (u *TesterConn) Get(block bool) *Packet { //********************************************************************************************************************// func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { - if u.closed.Load() { - return io.ErrClosedPipe - } - p := &Packet{ Data: make([]byte, len(b), len(b)), From: u.Addr, @@ -103,17 +112,22 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { } copy(p.Data, b) - u.TxPackets <- p - return nil + select { + case <-u.done: + return io.ErrClosedPipe + case u.TxPackets <- p: + return nil + } } func (u *TesterConn) ListenOut(r EncReader) error { for { - p, ok := <-u.RxPackets - if !ok { + select { + case <-u.done: return os.ErrClosed + case p := <-u.RxPackets: + r(p.From, p.Data) } - r(p.From, p.Data) } } @@ -137,9 +151,8 @@ func (u *TesterConn) Rebind() error { } func (u *TesterConn) Close() error { - if u.closed.CompareAndSwap(false, true) { - close(u.RxPackets) - close(u.TxPackets) - } + u.closeOnce.Do(func() { + close(u.done) + }) return nil }