From e80b9830a3a7aa0a7080fd6ebfd53b22cc70e6e4 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Mon, 20 Apr 2026 16:08:26 -0500 Subject: [PATCH] Remove more os.Exit calls and give a more reliable wait for stop function (attempt 3) (#1661) --- cmd/nebula-service/main.go | 16 ++- cmd/nebula/main.go | 16 ++- control.go | 78 ++++++++++- control_test.go | 1 + interface.go | 95 +++++++++---- main.go | 19 +-- overlay/tun_file_linux_test.go | 120 ++++++++++++++++ overlay/tun_linux.go | 242 ++++++++++++++++++++++++++++++--- service/service.go | 11 +- udp/conn.go | 6 +- udp/udp_darwin.go | 5 +- udp/udp_generic.go | 5 +- udp/udp_linux.go | 22 ++- udp/udp_rio_windows.go | 5 +- udp/udp_tester.go | 5 +- 15 files changed, 552 insertions(+), 94 deletions(-) create mode 100644 overlay/tun_file_linux_test.go diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 9a17b947..aaec80f7 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -78,8 +78,20 @@ func main() { } if !*configTest { - ctrl.Start() - ctrl.ShutdownBlock() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() + + if err := wait(); err != nil { + l.WithError(err).Error("Nebula stopped due to fatal error") + os.Exit(2) + } + + l.Info("Goodbye") } os.Exit(0) diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index ffdc15bf..f29f4537 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -72,9 +72,21 @@ func main() { } if !*configTest { - ctrl.Start() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() notifyReady(l) - ctrl.ShutdownBlock() + + if err := wait(); err != nil { + l.WithError(err).Error("Nebula stopped due to fatal error") + os.Exit(2) + } + + l.Info("Goodbye") } os.Exit(0) diff --git a/control.go b/control.go index f8567b50..75eccef1 100644 --- a/control.go +++ b/control.go @@ -2,9 +2,11 @@ package nebula import ( "context" + "errors" "net/netip" "os" "os/signal" + "sync" "syscall" "github.com/sirupsen/logrus" @@ -13,6 +15,20 @@ import ( "github.com/slackhq/nebula/overlay" ) +type RunState int + +const ( + StateUnknown RunState = iota + StateReady + StateStarted + StateStopping + StateStopped +) + +var ErrAlreadyStarted = errors.New("nebula is already started") +var ErrAlreadyStopped = errors.New("nebula cannot be restarted") +var ErrUnknownState = errors.New("nebula state is invalid") + // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc @@ -26,6 +42,9 @@ type controlHostLister interface { } type Control struct { + stateLock sync.Mutex + state RunState + f *Interface l *logrus.Logger ctx context.Context @@ -49,10 +68,31 @@ type ControlHostInfo struct { CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } -// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() -func (c *Control) Start() { +// Start actually runs nebula, this is a nonblocking call. +// The returned function blocks until nebula has fully stopped and returns the +// first fatal reader error (if any). A nil error means nebula shut down +// gracefully; a non-nil error means a reader hit an unexpected failure that +// triggered the shutdown. +func (c *Control) Start() (func() error, error) { + c.stateLock.Lock() + defer c.stateLock.Unlock() + switch c.state { + case StateReady: + //yay! + case StateStopped, StateStopping: + return nil, ErrAlreadyStopped + case StateStarted: + return nil, ErrAlreadyStarted + default: + return nil, ErrUnknownState + } + // Activate the interface - c.f.activate() + err := c.f.activate() + if err != nil { + c.state = StateStopped + return nil, err + } // Call all the delayed funcs that waited patiently for the interface to be created. if c.sshStart != nil { @@ -71,16 +111,40 @@ func (c *Control) Start() { c.lighthouseStart() } + c.f.triggerShutdown = c.Stop + // Start reading packets. - c.f.run() + out, err := c.f.run() + if err != nil { + c.state = StateStopped + return nil, err + } + c.state = StateStarted + return out, nil +} + +func (c *Control) State() RunState { + c.stateLock.Lock() + defer c.stateLock.Unlock() + return c.state } func (c *Control) Context() context.Context { return c.ctx } -// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete +// Stop is a non-blocking call that signals nebula to close all tunnels and shut down func (c *Control) Stop() { + c.stateLock.Lock() + if c.state != StateStarted { + c.stateLock.Unlock() + // We are stopping or stopped already + return + } + + c.state = StateStopping + c.stateLock.Unlock() + // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. c.cancel() @@ -89,7 +153,9 @@ func (c *Control) Stop() { if err := c.f.Close(); err != nil { c.l.WithError(err).Error("Close interface failed") } - c.l.Info("Goodbye") + c.stateLock.Lock() + c.state = StateStopped + c.stateLock.Unlock() } // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled diff --git a/control_test.go b/control_test.go index e8a5d312..558d8669 100644 --- a/control_test.go +++ b/control_test.go @@ -79,6 +79,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, &Interface{}) c := Control{ + state: StateReady, f: &Interface{ hostMap: hm, }, diff --git a/interface.go b/interface.go index 61f8c9b7..9e7a98a9 100644 --- a/interface.go +++ b/interface.go @@ -6,7 +6,7 @@ import ( "fmt" "io" "net/netip" - "os" + "sync" "sync/atomic" "time" @@ -87,6 +87,13 @@ type Interface struct { writers []udp.Conn readers []io.ReadWriteCloser + wg sync.WaitGroup + + // fatalErr holds the first unexpected reader error that caused shutdown. + // nil means "no fatal error" (yet) + fatalErr atomic.Pointer[error] + // triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr + triggerShutdown func() metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -209,7 +216,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { // activate creates the interface on the host. After the interface is created, any // other services that want to bind listeners to its IP may do so successfully. However, // the interface isn't going to process anything until run() is called. -func (f *Interface) activate() { +func (f *Interface) activate() error { // actually turn on tun dev addr, err := f.outside.LocalAddr() @@ -237,27 +244,54 @@ func (f *Interface) activate() { if i > 0 { reader, err = f.inside.NewMultiQueueReader() if err != nil { - f.l.Fatal(err) + return err } } f.readers[i] = reader } - if err := f.inside.Activate(); err != nil { + f.wg.Add(1) // for us to wait on Close() to return + if err = f.inside.Activate(); err != nil { + f.wg.Done() f.inside.Close() - f.l.Fatal(err) + return err } + + return nil } -func (f *Interface) run() { +func (f *Interface) run() (func() error, error) { // Launch n queues to read packets from udp for i := 0; i < f.routines; i++ { - go f.listenOut(i) + f.wg.Go(func() { + f.listenOut(i) + }) } // Launch n queues to read packets from tun dev for i := 0; i < f.routines; i++ { - go f.listenIn(f.readers[i], i) + f.wg.Go(func() { + f.listenIn(f.readers[i], i) + }) + } + + return func() error { + f.wg.Wait() + if e := f.fatalErr.Load(); e != nil { + return *e + } + return nil + }, nil +} + +// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one +func (f *Interface) onFatal(err error) { + swapped := f.fatalErr.CompareAndSwap(nil, &err) + if !swapped { + return + } + if f.triggerShutdown != nil { + f.triggerShutdown() } } @@ -276,9 +310,16 @@ func (f *Interface) listenOut(i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) + + if err != nil && !f.closed.Load() { + f.l.WithError(err).Error("Error while reading inbound packet, closing") + f.onFatal(err) + } + + f.l.Debugf("underlay reader %v is done", i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -292,17 +333,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { - return + if !f.closed.Load() { + f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") + f.onFatal(err) } - - f.l.WithError(err).Error("Error while reading outbound packet") - // This only seems to happen when something fatal happens to the fd, so exit. - os.Exit(2) + break } f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) } + + f.l.Debugf("overlay reader %v is done", i) } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { @@ -477,23 +518,23 @@ func (f *Interface) GetCertState() *CertState { } func (f *Interface) Close() error { + var errs []error f.closed.Store(true) - for _, u := range f.writers { + // Release the udp readers + for i, u := range f.writers { err := u.Close() if err != nil { - f.l.WithError(err).Error("Error while closing udp socket") - } - } - for i, r := range f.readers { - if i == 0 { - continue // f.readers[0] is f.inside, which we want to save for last - } - if err := r.Close(); err != nil { - f.l.WithError(err).Error("Error while closing tun reader") + f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket") + errs = append(errs, err) } } - // Release the tun device - return f.inside.Close() + // Release the tun device (closing the tun also closes all readers) + closeErr := f.inside.Close() + if closeErr != nil { + errs = append(errs, closeErr) + } + f.wg.Done() + return errors.Join(errs...) } diff --git a/main.go b/main.go index 74979417..8adc2921 100644 --- a/main.go +++ b/main.go @@ -288,15 +288,16 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } return &Control{ - ifce, - l, - ctx, - cancel, - sshStart, - statsStart, - dnsStart, - lightHouse.StartUpdateWorker, - connManager.Start, + state: StateReady, + f: ifce, + l: l, + ctx: ctx, + cancel: cancel, + sshStart: sshStart, + statsStart: statsStart, + dnsStart: dnsStart, + lighthouseStart: lightHouse.StartUpdateWorker, + connectionManagerStart: connManager.Start, }, nil } diff --git a/overlay/tun_file_linux_test.go b/overlay/tun_file_linux_test.go new file mode 100644 index 00000000..5ab87e05 --- /dev/null +++ b/overlay/tun_file_linux_test.go @@ -0,0 +1,120 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package overlay + +import ( + "errors" + "os" + "sync" + "testing" + "time" + + "golang.org/x/sys/unix" +) + +// newReadPipe returns a read fd. The matching write fd is registered for cleanup. +// The caller takes ownership of the read fd (pass it to newTunFd / newFriend). +func newReadPipe(t *testing.T) int { + t.Helper() + var fds [2]int + if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil { + t.Fatalf("pipe2: %v", err) + } + t.Cleanup(func() { _ = unix.Close(fds[1]) }) + return fds[0] +} + +func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) { + tf, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + t.Cleanup(func() { _ = tf.Close() }) + + done := make(chan error, 1) + go func() { + _, err := tf.Read(make([]byte, 64)) + done <- err + }() + + // Verify Read is actually blocked in poll. + select { + case err := <-done: + t.Fatalf("Read returned before shutdown signal: %v", err) + case <-time.After(50 * time.Millisecond): + } + + if err := tf.wakeForShutdown(); err != nil { + t.Fatalf("wakeForShutdown: %v", err) + } + + select { + case err := <-done: + if !errors.Is(err, os.ErrClosed) { + t.Fatalf("expected os.ErrClosed, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Read did not wake on shutdown") + } +} + +func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) { + parent, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + friend, err := parent.newFriend(newReadPipe(t)) + if err != nil { + _ = parent.Close() + t.Fatalf("newFriend: %v", err) + } + t.Cleanup(func() { + _ = friend.Close() + _ = parent.Close() + }) + + readers := []*tunFile{parent, friend} + errs := make([]error, len(readers)) + var wg sync.WaitGroup + for i, r := range readers { + wg.Add(1) + go func(i int, r *tunFile) { + defer wg.Done() + _, errs[i] = r.Read(make([]byte, 64)) + }(i, r) + } + + time.Sleep(50 * time.Millisecond) + + if err := parent.wakeForShutdown(); err != nil { + t.Fatalf("wakeForShutdown: %v", err) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("readers did not wake") + } + + for i, err := range errs { + if !errors.Is(err, os.ErrClosed) { + t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err) + } + } +} + +func TestTunFile_Close_Idempotent(t *testing.T) { + tf, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("second Close should be a no-op, got %v", err) + } +} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 9d779a4b..2830ff6b 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,6 +4,7 @@ package overlay import ( + "encoding/binary" "fmt" "io" "net" @@ -24,9 +25,175 @@ import ( "golang.org/x/sys/unix" ) +// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. +// A shared eventfd allows Close to wake all readers blocked in poll. +type tunFile struct { + fd int + shutdownFd int + lastOne bool + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed bool +} + +// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun +func (r *tunFile) newFriend(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + return &tunFile{ + fd: fd, + shutdownFd: r.shutdownFd, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + }, nil +} + +func newTunFd(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &tunFile{ + fd: fd, + shutdownFd: shutdownFd, + lastOne: true, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + } + + return out, nil +} + +func (r *tunFile) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.readPoll[0].Revents + shutdownEvents := r.readPoll[1].Revents + r.readPoll[0].Revents = 0 + r.readPoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.writePoll[0].Revents + shutdownEvents := r.writePoll[1].Revents + r.writePoll[0].Revents = 0 + r.writePoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) Read(buf []byte) (int, error) { + for { + if n, err := unix.Read(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnRead(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +func (r *tunFile) Write(buf []byte) (int, error) { + for { + if n, err := unix.Write(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnWrite(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +func (r *tunFile) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(r.readPoll[1].Fd), buf[:]) + return err +} + +func (r *tunFile) Close() error { + if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem + return nil + } + r.closed = true + if r.lastOne { + _ = unix.Close(r.shutdownFd) + } + return unix.Close(r.fd) +} + type tun struct { - io.ReadWriteCloser - fd int + *tunFile + readers []*tunFile + closeLock sync.Mutex Device string vpnNetworks []netip.Prefix MaxMTU int @@ -72,9 +239,7 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) if err != nil { return nil, err } @@ -115,6 +280,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu nameStr := c.GetString("tun.dev", "") copy(req.Name[:], nameStr) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, &NameError{ Name: nameStr, Underlying: err, @@ -122,8 +288,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } name := strings.Trim(string(req.Name[:]), "\x00") - file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, fd, vpnNetworks) if err != nil { return nil, err } @@ -133,10 +298,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. +func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { + tfd, err := newTunFd(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } t := &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), + tunFile: tfd, + readers: []*tunFile{tfd}, + closeLock: sync.Mutex{}, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), @@ -145,8 +317,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n l: l, } - err := t.reload(c, true) - if err != nil { + if err = t.reload(c, true); err != nil { + _ = t.Close() return nil, err } @@ -239,6 +411,9 @@ func (t *tun) SupportsMultiqueue() bool { } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + t.closeLock.Lock() + defer t.closeLock.Unlock() + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -248,12 +423,19 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + out, err := t.tunFile.newFriend(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } - return file, nil + t.readers = append(t.readers, out) + + return out, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -684,18 +866,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } func (t *tun) Close() error { + t.closeLock.Lock() + defer t.closeLock.Unlock() + if t.routeChan != nil { close(t.routeChan) + t.routeChan = nil } - if t.ReadWriteCloser != nil { - _ = t.ReadWriteCloser.Close() - } + // Signal all readers blocked in poll to wake up and exit + _ = t.tunFile.wakeForShutdown() if t.ioctlFd > 0 { - _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() + _ = unix.Close(int(t.ioctlFd)) t.ioctlFd = 0 } - return nil + for i := range t.readers { + if i == 0 { + continue //we want to close the zeroth reader last + } + err := t.readers[i].Close() + if err != nil { + t.l.WithField("reader", i).WithError(err).Error("error closing tun reader") + } else { + t.l.WithField("reader", i).Info("closed tun reader") + } + } + + //this is t.readers[0] too + err := t.tunFile.Close() + if err != nil { + t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader") + } else { + t.l.WithField("reader", 0).Info("closed tun reader") + } + return err } diff --git a/service/service.go b/service/service.go index fc8ac97a..899e851d 100644 --- a/service/service.go +++ b/service/service.go @@ -44,7 +44,10 @@ type Service struct { } func New(control *nebula.Control) (*Service, error) { - control.Start() + wait, err := control.Start() + if err != nil { + return nil, err + } ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) @@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) { } }) + // Add the nebula wait function to the group so a fatal reader error + // propagates out through errgroup.Wait(). + eg.Go(func() error { + return wait() + }) + return &s, nil } diff --git a/udp/conn.go b/udp/conn.go index 1ae585c2..30d89dec 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -16,7 +16,7 @@ type EncReader func( type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) + ListenOut(r EncReader) error WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool @@ -31,8 +31,8 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return +func (NoopConn) ListenOut(_ EncReader) error { + return nil } func (NoopConn) SupportsMultipleReaders() bool { return false diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 91201194..863c98f3 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) for { @@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) { n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } u.l.WithError(err).Error("unexpected udp socket receive error") diff --git a/udp/udp_generic.go b/udp/udp_generic.go index e9dad6c5..ad26f794 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -73,7 +73,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) { +func (u *GenericConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -83,8 +83,7 @@ func (u *GenericConn) ListenOut(r EncReader) { n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { diff --git a/udp/udp_linux.go b/udp/udp_linux.go index b1490a1c..21a34147 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { return int(n), true, nil } -func (u *StdConn) listenOutSingle(r EncReader) { +func (u *StdConn) listenOutSingle(r EncReader) error { var err error var n int var from netip.AddrPort @@ -180,15 +180,14 @@ func (u *StdConn) listenOutSingle(r EncReader) { for { n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) r(from, buffer[:n]) } } -func (u *StdConn) listenOutBatch(r EncReader) { +func (u *StdConn) listenOutBatch(r EncReader) error { var ip netip.Addr var n int var operr error @@ -205,12 +204,10 @@ func (u *StdConn) listenOutBatch(r EncReader) { for { err := u.rawConn.Read(reader) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } if operr != nil { - u.l.WithError(operr).Debug("operr: udp socket is closed, exiting read loop") - return + return operr } for i := 0; i < n; i++ { @@ -225,14 +222,11 @@ func (u *StdConn) listenOutBatch(r EncReader) { } } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { if u.batch == 1 { - //save some ram by not calling PrepareRawMessages for fields we won't use - //we could also make this path more common by calling recvmmsg with msgs[:1], - //but that's still the recvmmsg syscall, which would be a change - u.listenOutSingle(r) + return u.listenOutSingle(r) } else { - u.listenOutBatch(r) + return u.listenOutBatch(r) } } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 3d60f34c..607b978e 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) { +func (u *RIOConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -151,8 +151,7 @@ func (u *RIOConn) ListenOut(r EncReader) { if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5f0f7765..5db72555 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -6,6 +6,7 @@ package udp import ( "io" "net/netip" + "os" "sync/atomic" "github.com/sirupsen/logrus" @@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader) { +func (u *TesterConn) ListenOut(r EncReader) error { for { p, ok := <-u.RxPackets if !ok { - return + return os.ErrClosed } r(p.From, p.Data) }