From 6592a07b5136b4973b4fe7d866115d884ee4c81f Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 2 Apr 2025 09:51:59 -0500 Subject: [PATCH] Remove more os.Exit calls and give a more reliable wait for stop function --- cmd/nebula-service/main.go | 12 ++++++-- cmd/nebula/main.go | 12 ++++++-- control.go | 56 ++++++++++++++++++++++++++++++++++---- interface.go | 52 +++++++++++++++++++++++++---------- main.go | 18 ++++++------ overlay/tun_linux.go | 19 +++++++++---- service/service.go | 11 +++++++- udp/conn.go | 6 ++-- udp/udp_darwin.go | 5 ++-- udp/udp_generic.go | 17 ++---------- udp/udp_linux.go | 24 ++++++---------- udp/udp_rio_windows.go | 5 ++-- udp/udp_tester.go | 5 ++-- 13 files changed, 161 insertions(+), 81 deletions(-) diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 9a17b947..efbcc8b8 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -78,8 +78,16 @@ func main() { } if !*configTest { - ctrl.Start() - ctrl.ShutdownBlock() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() + wait() + + l.Info("Goodbye") } os.Exit(0) diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index ffdc15bf..fa2c5e7f 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -72,9 +72,17 @@ func main() { } if !*configTest { - ctrl.Start() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() notifyReady(l) - ctrl.ShutdownBlock() + wait() + + l.Info("Goodbye") } os.Exit(0) diff --git a/control.go b/control.go index f8567b50..2d07de59 100644 --- a/control.go +++ b/control.go @@ -2,9 +2,11 @@ package nebula import ( "context" + "errors" "net/netip" "os" "os/signal" + "sync" "syscall" "github.com/sirupsen/logrus" @@ -13,6 +15,16 @@ import ( "github.com/slackhq/nebula/overlay" ) +type RunState int + +const ( + Stopped RunState = 0 // The control has yet to be started + Started RunState = 1 // The control has been started + Stopping RunState = 2 // The control is stopping +) + +var ErrAlreadyStarted = errors.New("nebula is already started") + // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc @@ -26,6 +38,9 @@ type controlHostLister interface { } type Control struct { + stateLock sync.Mutex + state RunState + f *Interface l *logrus.Logger ctx context.Context @@ -49,10 +64,21 @@ type ControlHostInfo struct { CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } -// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() -func (c *Control) Start() { +// Start actually runs nebula, this is a nonblocking call. +// The returned function can be used to wait for nebula to fully stop. +func (c *Control) Start() (func(), error) { + c.stateLock.Lock() + if c.state != Stopped { + c.stateLock.Unlock() + return nil, ErrAlreadyStarted + } + // Activate the interface - c.f.activate() + err := c.f.activate() + if err != nil { + c.stateLock.Unlock() + return nil, err + } // Call all the delayed funcs that waited patiently for the interface to be created. if c.sshStart != nil { @@ -72,15 +98,33 @@ func (c *Control) Start() { } // Start reading packets. - c.f.run() + c.state = Started + c.stateLock.Unlock() + return c.f.run() +} + +func (c *Control) State() RunState { + c.stateLock.Lock() + defer c.stateLock.Unlock() + return c.state } func (c *Control) Context() context.Context { return c.ctx } -// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete +// Stop is a non-blocking call that signals nebula to close all tunnels and shut down func (c *Control) Stop() { + c.stateLock.Lock() + if c.state != Started { + c.stateLock.Unlock() + // We are stopping or stopped already + return + } + + c.state = Stopping + c.stateLock.Unlock() + // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. c.cancel() @@ -89,7 +133,7 @@ func (c *Control) Stop() { if err := c.f.Close(); err != nil { c.l.WithError(err).Error("Close interface failed") } - c.l.Info("Goodbye") + c.state = Stopped } // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled diff --git a/interface.go b/interface.go index 61f8c9b7..97d9bf0c 100644 --- a/interface.go +++ b/interface.go @@ -7,6 +7,7 @@ import ( "io" "net/netip" "os" + "sync" "sync/atomic" "time" @@ -87,6 +88,7 @@ type Interface struct { writers []udp.Conn readers []io.ReadWriteCloser + wg sync.WaitGroup metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -209,7 +211,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { // activate creates the interface on the host. After the interface is created, any // other services that want to bind listeners to its IP may do so successfully. However, // the interface isn't going to process anything until run() is called. -func (f *Interface) activate() { +func (f *Interface) activate() error { // actually turn on tun dev addr, err := f.outside.LocalAddr() @@ -237,28 +239,36 @@ func (f *Interface) activate() { if i > 0 { reader, err = f.inside.NewMultiQueueReader() if err != nil { - f.l.Fatal(err) + return err } } f.readers[i] = reader } - if err := f.inside.Activate(); err != nil { + if err = f.inside.Activate(); err != nil { f.inside.Close() - f.l.Fatal(err) + return err } + + return nil } -func (f *Interface) run() { +func (f *Interface) run() (func(), error) { // Launch n queues to read packets from udp for i := 0; i < f.routines; i++ { - go f.listenOut(i) + f.wg.Go(func() { + f.listenOut(i) + }) } // Launch n queues to read packets from tun dev for i := 0; i < f.routines; i++ { - go f.listenIn(f.readers[i], i) + f.wg.Go(func() { + f.listenIn(f.readers[i], i) + }) } + + return f.wg.Wait, nil } func (f *Interface) listenOut(i int) { @@ -276,9 +286,17 @@ func (f *Interface) listenOut(i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) + + if err != nil && !f.closed.Load() { + f.l.WithError(err).Error("Error while reading packet inbound packet, closing") + //TODO: Trigger Control to close + } + + f.l.Debugf("underlay reader %v is done", i) + f.wg.Done() } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -292,17 +310,18 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { - return + if !f.closed.Load() { + f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") + //TODO: Trigger Control to close } - - f.l.WithError(err).Error("Error while reading outbound packet") - // This only seems to happen when something fatal happens to the fd, so exit. - os.Exit(2) + break } f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) } + + f.l.Debugf("overlay reader %v is done", i) + f.wg.Done() } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { @@ -479,15 +498,18 @@ func (f *Interface) GetCertState() *CertState { func (f *Interface) Close() error { f.closed.Store(true) + // Release the udp readers for _, u := range f.writers { err := u.Close() if err != nil { f.l.WithError(err).Error("Error while closing udp socket") } } + + // Release the tun readers for i, r := range f.readers { if i == 0 { - continue // f.readers[0] is f.inside, which we want to save for last + continue // f.readers[0] is f.inside, which we want to save for last, since it closes other stuff too } if err := r.Close(); err != nil { f.l.WithError(err).Error("Error while closing tun reader") diff --git a/main.go b/main.go index 74979417..8d17c8ea 100644 --- a/main.go +++ b/main.go @@ -288,15 +288,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } return &Control{ - ifce, - l, - ctx, - cancel, - sshStart, - statsStart, - dnsStart, - lightHouse.StartUpdateWorker, - connManager.Start, + f: ifce, + l: l, + ctx: ctx, + cancel: cancel, + sshStart: sshStart, + statsStart: statsStart, + dnsStart: dnsStart, + lighthouseStart: lightHouse.StartUpdateWorker, + connectionManagerStart: connManager.Start, }, nil } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 9d779a4b..96f32079 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -686,16 +686,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { func (t *tun) Close() error { if t.routeChan != nil { close(t.routeChan) - } - - if t.ReadWriteCloser != nil { - _ = t.ReadWriteCloser.Close() + t.routeChan = nil } if t.ioctlFd > 0 { - _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() + err := os.NewFile(t.ioctlFd, "ioctlFd").Close() + if err != nil { + t.l.WithField("error", err).Error("Failed to close ioctl fd") + } t.ioctlFd = 0 } + if t.ReadWriteCloser != nil { + err := t.ReadWriteCloser.Close() + if err != nil { + t.l.WithField("error", err).Error("Failed to close tun file") + return err + } + t.ReadWriteCloser = nil + } + return nil } diff --git a/service/service.go b/service/service.go index fc8ac97a..c86d08c3 100644 --- a/service/service.go +++ b/service/service.go @@ -44,7 +44,10 @@ type Service struct { } func New(control *nebula.Control) (*Service, error) { - control.Start() + wait, err := control.Start() + if err != nil { + return nil, err + } ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) @@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) { } }) + // Add the nebula wait function to the group + eg.Go(func() error { + wait() + return nil + }) + return &s, nil } diff --git a/udp/conn.go b/udp/conn.go index 1ae585c2..30d89dec 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -16,7 +16,7 @@ type EncReader func( type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) + ListenOut(r EncReader) error WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool @@ -31,8 +31,8 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return +func (NoopConn) ListenOut(_ EncReader) error { + return nil } func (NoopConn) SupportsMultipleReaders() bool { return false diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 91201194..863c98f3 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) for { @@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) { n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } u.l.WithError(err).Error("unexpected udp socket receive error") diff --git a/udp/udp_generic.go b/udp/udp_generic.go index e9dad6c5..44632fed 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -10,11 +10,9 @@ package udp import ( "context" - "errors" "fmt" "net" "net/netip" - "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -73,25 +71,14 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) { +func (u *GenericConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) - var lastRecvErr time.Time - for { // Just read one packet at a time n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { - if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return - } - // Dampen unexpected message warns to once per minute - if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { - lastRecvErr = time.Now() - u.l.WithError(err).Warn("unexpected udp socket receive error") - } - continue + return err } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index b1490a1c..4c24b09d 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { return int(n), true, nil } -func (u *StdConn) listenOutSingle(r EncReader) { +func (u *StdConn) listenOutSingle(r EncReader) error { var err error var n int var from netip.AddrPort @@ -180,15 +180,14 @@ func (u *StdConn) listenOutSingle(r EncReader) { for { n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) r(from, buffer[:n]) } } -func (u *StdConn) listenOutBatch(r EncReader) { +func (u *StdConn) listenOutBatch(r EncReader) error { var ip netip.Addr var n int var operr error @@ -205,12 +204,10 @@ func (u *StdConn) listenOutBatch(r EncReader) { for { err := u.rawConn.Read(reader) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } if operr != nil { - u.l.WithError(operr).Debug("operr: udp socket is closed, exiting read loop") - return + return operr } for i := 0; i < n; i++ { @@ -225,14 +222,11 @@ func (u *StdConn) listenOutBatch(r EncReader) { } } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { if u.batch == 1 { - //save some ram by not calling PrepareRawMessages for fields we won't use - //we could also make this path more common by calling recvmmsg with msgs[:1], - //but that's still the recvmmsg syscall, which would be a change - u.listenOutSingle(r) + return u.listenOutSingle(r) } else { - u.listenOutBatch(r) + return u.listenOutBatch(r) } } @@ -290,7 +284,7 @@ func (u *StdConn) ReloadConfig(c *config.C) { } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { - var vallen uint32 = 4 * unix.SK_MEMINFO_VARS + const vallen uint32 = 4 * unix.SK_MEMINFO_VARS if u.rawConn == nil { return fmt.Errorf("no UDP connection") diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 3d60f34c..607b978e 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) { +func (u *RIOConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -151,8 +151,7 @@ func (u *RIOConn) ListenOut(r EncReader) { if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5f0f7765..5db72555 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -6,6 +6,7 @@ package udp import ( "io" "net/netip" + "os" "sync/atomic" "github.com/sirupsen/logrus" @@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader) { +func (u *TesterConn) ListenOut(r EncReader) error { for { p, ok := <-u.RxPackets if !ok { - return + return os.ErrClosed } r(p.From, p.Data) }