Remove more os.Exit calls and give a more reliable wait for stop function (attempt 3) (#1661)

This commit is contained in:
Jack Doan
2026-04-20 16:08:26 -05:00
committed by GitHub
parent 49e3c4649b
commit e80b9830a3
15 changed files with 552 additions and 94 deletions

View File

@@ -78,8 +78,20 @@ func main() {
}
if !*configTest {
ctrl.Start()
ctrl.ShutdownBlock()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
if err := wait(); err != nil {
l.WithError(err).Error("Nebula stopped due to fatal error")
os.Exit(2)
}
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -72,9 +72,21 @@ func main() {
}
if !*configTest {
ctrl.Start()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l)
ctrl.ShutdownBlock()
if err := wait(); err != nil {
l.WithError(err).Error("Nebula stopped due to fatal error")
os.Exit(2)
}
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -2,9 +2,11 @@ package nebula
import (
"context"
"errors"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
@@ -13,6 +15,20 @@ import (
"github.com/slackhq/nebula/overlay"
)
type RunState int
const (
StateUnknown RunState = iota
StateReady
StateStarted
StateStopping
StateStopped
)
var ErrAlreadyStarted = errors.New("nebula is already started")
var ErrAlreadyStopped = errors.New("nebula cannot be restarted")
var ErrUnknownState = errors.New("nebula state is invalid")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,6 +42,9 @@ type controlHostLister interface {
}
type Control struct {
stateLock sync.Mutex
state RunState
f *Interface
l *logrus.Logger
ctx context.Context
@@ -49,10 +68,31 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Start actually runs nebula, this is a nonblocking call.
// The returned function blocks until nebula has fully stopped and returns the
// first fatal reader error (if any). A nil error means nebula shut down
// gracefully; a non-nil error means a reader hit an unexpected failure that
// triggered the shutdown.
func (c *Control) Start() (func() error, error) {
c.stateLock.Lock()
defer c.stateLock.Unlock()
switch c.state {
case StateReady:
//yay!
case StateStopped, StateStopping:
return nil, ErrAlreadyStopped
case StateStarted:
return nil, ErrAlreadyStarted
default:
return nil, ErrUnknownState
}
// Activate the interface
c.f.activate()
err := c.f.activate()
if err != nil {
c.state = StateStopped
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
@@ -71,16 +111,40 @@ func (c *Control) Start() {
c.lighthouseStart()
}
c.f.triggerShutdown = c.Stop
// Start reading packets.
c.f.run()
out, err := c.f.run()
if err != nil {
c.state = StateStopped
return nil, err
}
c.state = StateStarted
return out, nil
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
}
func (c *Control) Context() context.Context {
return c.ctx
}
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != StateStarted {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = StateStopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
c.cancel()
@@ -89,7 +153,9 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
}
c.l.Info("Goodbye")
c.stateLock.Lock()
c.state = StateStopped
c.stateLock.Unlock()
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

View File

@@ -79,6 +79,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, &Interface{})
c := Control{
state: StateReady,
f: &Interface{
hostMap: hm,
},

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"io"
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
@@ -87,6 +87,13 @@ type Interface struct {
writers []udp.Conn
readers []io.ReadWriteCloser
wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown.
// nil means "no fatal error" (yet)
fatalErr atomic.Pointer[error]
// triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr
triggerShutdown func()
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
@@ -209,7 +216,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
func (f *Interface) activate() error {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
@@ -237,27 +244,54 @@ func (f *Interface) activate() {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
f.l.Fatal(err)
return err
}
}
f.readers[i] = reader
}
if err := f.inside.Activate(); err != nil {
f.wg.Add(1) // for us to wait on Close() to return
if err = f.inside.Activate(); err != nil {
f.wg.Done()
f.inside.Close()
f.l.Fatal(err)
return err
}
return nil
}
func (f *Interface) run() {
func (f *Interface) run() (func() error, error) {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
go f.listenOut(i)
f.wg.Go(func() {
f.listenOut(i)
})
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
f.wg.Go(func() {
f.listenIn(f.readers[i], i)
})
}
return func() error {
f.wg.Wait()
if e := f.fatalErr.Load(); e != nil {
return *e
}
return nil
}, nil
}
// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one
func (f *Interface) onFatal(err error) {
swapped := f.fatalErr.CompareAndSwap(nil, &err)
if !swapped {
return
}
if f.triggerShutdown != nil {
f.triggerShutdown()
}
}
@@ -276,9 +310,16 @@ func (f *Interface) listenOut(i int) {
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading inbound packet, closing")
f.onFatal(err)
}
f.l.Debugf("underlay reader %v is done", i)
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -292,17 +333,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
for {
n, err := reader.Read(packet)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
if !f.closed.Load() {
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
f.onFatal(err)
}
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
break
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
f.l.Debugf("overlay reader %v is done", i)
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -477,23 +518,23 @@ func (f *Interface) GetCertState() *CertState {
}
func (f *Interface) Close() error {
var errs []error
f.closed.Store(true)
for _, u := range f.writers {
// Release the udp readers
for i, u := range f.writers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing udp socket")
}
}
for i, r := range f.readers {
if i == 0 {
continue // f.readers[0] is f.inside, which we want to save for last
}
if err := r.Close(); err != nil {
f.l.WithError(err).Error("Error while closing tun reader")
f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket")
errs = append(errs, err)
}
}
// Release the tun device
return f.inside.Close()
// Release the tun device (closing the tun also closes all readers)
closeErr := f.inside.Close()
if closeErr != nil {
errs = append(errs, closeErr)
}
f.wg.Done()
return errors.Join(errs...)
}

19
main.go
View File

@@ -288,15 +288,16 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
return &Control{
ifce,
l,
ctx,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
connManager.Start,
state: StateReady,
f: ifce,
l: l,
ctx: ctx,
cancel: cancel,
sshStart: sshStart,
statsStart: statsStart,
dnsStart: dnsStart,
lighthouseStart: lightHouse.StartUpdateWorker,
connectionManagerStart: connManager.Start,
}, nil
}

View File

@@ -0,0 +1,120 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay
import (
"errors"
"os"
"sync"
"testing"
"time"
"golang.org/x/sys/unix"
)
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
func newReadPipe(t *testing.T) int {
t.Helper()
var fds [2]int
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
t.Fatalf("pipe2: %v", err)
}
t.Cleanup(func() { _ = unix.Close(fds[1]) })
return fds[0]
}
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
tf, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
t.Cleanup(func() { _ = tf.Close() })
done := make(chan error, 1)
go func() {
_, err := tf.Read(make([]byte, 64))
done <- err
}()
// Verify Read is actually blocked in poll.
select {
case err := <-done:
t.Fatalf("Read returned before shutdown signal: %v", err)
case <-time.After(50 * time.Millisecond):
}
if err := tf.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
}
select {
case err := <-done:
if !errors.Is(err, os.ErrClosed) {
t.Fatalf("expected os.ErrClosed, got %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Read did not wake on shutdown")
}
}
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
parent, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
friend, err := parent.newFriend(newReadPipe(t))
if err != nil {
_ = parent.Close()
t.Fatalf("newFriend: %v", err)
}
t.Cleanup(func() {
_ = friend.Close()
_ = parent.Close()
})
readers := []*tunFile{parent, friend}
errs := make([]error, len(readers))
var wg sync.WaitGroup
for i, r := range readers {
wg.Add(1)
go func(i int, r *tunFile) {
defer wg.Done()
_, errs[i] = r.Read(make([]byte, 64))
}(i, r)
}
time.Sleep(50 * time.Millisecond)
if err := parent.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
}
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("readers did not wake")
}
for i, err := range errs {
if !errors.Is(err, os.ErrClosed) {
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
}
}
}
func TestTunFile_Close_Idempotent(t *testing.T) {
tf, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("first Close: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("second Close should be a no-op, got %v", err)
}
}

View File

@@ -4,6 +4,7 @@
package overlay
import (
"encoding/binary"
"fmt"
"io"
"net"
@@ -24,9 +25,175 @@ import (
"golang.org/x/sys/unix"
)
type tun struct {
io.ReadWriteCloser
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct {
fd int
shutdownFd int
lastOne bool
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed bool
}
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
return &tunFile{
fd: fd,
shutdownFd: r.shutdownFd,
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
}, nil
}
func newTunFd(fd int) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &tunFile{
fd: fd,
shutdownFd: shutdownFd,
lastOne: true,
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
}
return out, nil
}
func (r *tunFile) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.readPoll[0].Revents
shutdownEvents := r.readPoll[1].Revents
r.readPoll[0].Revents = 0
r.readPoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.writePoll[0].Revents
shutdownEvents := r.writePoll[1].Revents
r.writePoll[0].Revents = 0
r.writePoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) Read(buf []byte) (int, error) {
for {
if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnRead(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else if err == unix.EBADF {
return 0, os.ErrClosed
} else {
return 0, err
}
}
}
func (r *tunFile) Write(buf []byte) (int, error) {
for {
if n, err := unix.Write(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnWrite(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else if err == unix.EBADF {
return 0, os.ErrClosed
} else {
return 0, err
}
}
}
func (r *tunFile) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
return err
}
func (r *tunFile) Close() error {
if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
return nil
}
r.closed = true
if r.lastOne {
_ = unix.Close(r.shutdownFd)
}
return unix.Close(r.fd)
}
type tun struct {
*tunFile
readers []*tunFile
closeLock sync.Mutex
Device string
vpnNetworks []netip.Prefix
MaxMTU int
@@ -72,9 +239,7 @@ type ifreqQLEN struct {
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
if err != nil {
return nil, err
}
@@ -115,6 +280,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
nameStr := c.GetString("tun.dev", "")
copy(req.Name[:], nameStr)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
_ = unix.Close(fd)
return nil, &NameError{
Name: nameStr,
Underlying: err,
@@ -122,8 +288,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, fd, vpnNetworks)
if err != nil {
return nil, err
}
@@ -133,10 +298,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
tfd, err := newTunFd(fd)
if err != nil {
_ = unix.Close(fd)
return nil, err
}
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
tunFile: tfd,
readers: []*tunFile{tfd},
closeLock: sync.Mutex{},
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
@@ -145,8 +317,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
l: l,
}
err := t.reload(c, true)
if err != nil {
if err = t.reload(c, true); err != nil {
_ = t.Close()
return nil, err
}
@@ -239,6 +411,9 @@ func (t *tun) SupportsMultiqueue() bool {
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
t.closeLock.Lock()
defer t.closeLock.Unlock()
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
@@ -248,12 +423,19 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
_ = unix.Close(fd)
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
out, err := t.tunFile.newFriend(fd)
if err != nil {
_ = unix.Close(fd)
return nil, err
}
return file, nil
t.readers = append(t.readers, out)
return out, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -684,18 +866,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
}
func (t *tun) Close() error {
t.closeLock.Lock()
defer t.closeLock.Unlock()
if t.routeChan != nil {
close(t.routeChan)
t.routeChan = nil
}
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
}
// Signal all readers blocked in poll to wake up and exit
_ = t.tunFile.wakeForShutdown()
if t.ioctlFd > 0 {
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
_ = unix.Close(int(t.ioctlFd))
t.ioctlFd = 0
}
return nil
for i := range t.readers {
if i == 0 {
continue //we want to close the zeroth reader last
}
err := t.readers[i].Close()
if err != nil {
t.l.WithField("reader", i).WithError(err).Error("error closing tun reader")
} else {
t.l.WithField("reader", i).Info("closed tun reader")
}
}
//this is t.readers[0] too
err := t.tunFile.Close()
if err != nil {
t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader")
} else {
t.l.WithField("reader", 0).Info("closed tun reader")
}
return err
}

View File

@@ -44,7 +44,10 @@ type Service struct {
}
func New(control *nebula.Control) (*Service, error) {
control.Start()
wait, err := control.Start()
if err != nil {
return nil, err
}
ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx)
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
}
})
// Add the nebula wait function to the group so a fatal reader error
// propagates out through errgroup.Wait().
eg.Go(func() error {
return wait()
})
return &s, nil
}

View File

@@ -16,7 +16,7 @@ type EncReader func(
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader)
ListenOut(r EncReader) error
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
SupportsMultipleReaders() bool
@@ -31,8 +31,8 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader) {
return
func (NoopConn) ListenOut(_ EncReader) error {
return nil
}
func (NoopConn) SupportsMultipleReaders() bool {
return false

View File

@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {}
}
func (u *StdConn) ListenOut(r EncReader) {
func (u *StdConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
for {
@@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) {
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
u.l.WithError(err).Error("unexpected udp socket receive error")

View File

@@ -73,7 +73,7 @@ type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader) {
func (u *GenericConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
var lastRecvErr time.Time
@@ -83,8 +83,7 @@ func (u *GenericConn) ListenOut(r EncReader) {
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
// Dampen unexpected message warns to once per minute
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {

View File

@@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) {
return int(n), true, nil
}
func (u *StdConn) listenOutSingle(r EncReader) {
func (u *StdConn) listenOutSingle(r EncReader) error {
var err error
var n int
var from netip.AddrPort
@@ -180,15 +180,14 @@ func (u *StdConn) listenOutSingle(r EncReader) {
for {
n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
r(from, buffer[:n])
}
}
func (u *StdConn) listenOutBatch(r EncReader) {
func (u *StdConn) listenOutBatch(r EncReader) error {
var ip netip.Addr
var n int
var operr error
@@ -205,12 +204,10 @@ func (u *StdConn) listenOutBatch(r EncReader) {
for {
err := u.rawConn.Read(reader)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
if operr != nil {
u.l.WithError(operr).Debug("operr: udp socket is closed, exiting read loop")
return
return operr
}
for i := 0; i < n; i++ {
@@ -225,14 +222,11 @@ func (u *StdConn) listenOutBatch(r EncReader) {
}
}
func (u *StdConn) ListenOut(r EncReader) {
func (u *StdConn) ListenOut(r EncReader) error {
if u.batch == 1 {
//save some ram by not calling PrepareRawMessages for fields we won't use
//we could also make this path more common by calling recvmmsg with msgs[:1],
//but that's still the recvmmsg syscall, which would be a change
u.listenOutSingle(r)
return u.listenOutSingle(r)
} else {
u.listenOutBatch(r)
return u.listenOutBatch(r)
}
}

View File

@@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
return nil
}
func (u *RIOConn) ListenOut(r EncReader) {
func (u *RIOConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
var lastRecvErr time.Time
@@ -151,8 +151,7 @@ func (u *RIOConn) ListenOut(r EncReader) {
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
// Dampen unexpected message warns to once per minute
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {

View File

@@ -6,6 +6,7 @@ package udp
import (
"io"
"net/netip"
"os"
"sync/atomic"
"github.com/sirupsen/logrus"
@@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil
}
func (u *TesterConn) ListenOut(r EncReader) {
func (u *TesterConn) ListenOut(r EncReader) error {
for {
p, ok := <-u.RxPackets
if !ok {
return
return os.ErrClosed
}
r(p.From, p.Data)
}