wait for goroutines to finish and for tun to actually be closed

This commit is contained in:
JackDoan
2026-04-16 13:19:25 -05:00
parent 183c1e3cfd
commit 6b2e6d9f55
6 changed files with 57 additions and 16 deletions

View File

@@ -85,7 +85,12 @@ func main() {
} }
go ctrl.ShutdownBlock() go ctrl.ShutdownBlock()
wait()
if err := wait(); err != nil {
l.WithError(err).Error("Nebula stopped due to fatal error")
l.Info("Goodbye")
os.Exit(2)
}
l.Info("Goodbye") l.Info("Goodbye")
} }

View File

@@ -80,7 +80,12 @@ func main() {
go ctrl.ShutdownBlock() go ctrl.ShutdownBlock()
notifyReady(l) notifyReady(l)
wait()
if err := wait(); err != nil {
l.WithError(err).Error("Nebula stopped due to fatal error")
l.Info("Goodbye")
os.Exit(2)
}
l.Info("Goodbye") l.Info("Goodbye")
} }

View File

@@ -65,8 +65,11 @@ type ControlHostInfo struct {
} }
// Start actually runs nebula, this is a nonblocking call. // Start actually runs nebula, this is a nonblocking call.
// The returned function can be used to wait for nebula to fully stop. // The returned function blocks until nebula has fully stopped and returns the
func (c *Control) Start() (func(), error) { // first fatal reader error (if any). A nil error means nebula shut down
// gracefully; a non-nil error means a reader hit an unexpected failure that
// triggered the shutdown.
func (c *Control) Start() (func() error, error) {
c.stateLock.Lock() c.stateLock.Lock()
if c.state != Stopped { if c.state != Stopped {
c.stateLock.Unlock() c.stateLock.Unlock()
@@ -97,6 +100,8 @@ func (c *Control) Start() (func(), error) {
c.lighthouseStart() c.lighthouseStart()
} }
c.f.triggerShutdown = c.Stop
// Start reading packets. // Start reading packets.
c.state = Started c.state = Started
c.stateLock.Unlock() c.stateLock.Unlock()

View File

@@ -89,6 +89,12 @@ type Interface struct {
readers []io.ReadWriteCloser readers []io.ReadWriteCloser
wg sync.WaitGroup wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown.
// nil means "no fatal error" (yet)
fatalErr atomic.Pointer[error]
// triggerShutdown, if non-nil, is called by onFatal at most once: only by the call that first stores an error into fatalErr
triggerShutdown func()
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
@@ -244,6 +250,7 @@ func (f *Interface) activate() error {
f.readers[i] = reader f.readers[i] = reader
} }
f.wg.Add(1) // for us to wait on Close() to return
if err = f.inside.Activate(); err != nil { if err = f.inside.Activate(); err != nil {
f.inside.Close() f.inside.Close()
return err return err
@@ -252,7 +259,7 @@ func (f *Interface) activate() error {
return nil return nil
} }
func (f *Interface) run() (func(), error) { func (f *Interface) run() (func() error, error) {
// Launch n queues to read packets from udp // Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
f.wg.Go(func() { f.wg.Go(func() {
@@ -267,7 +274,24 @@ func (f *Interface) run() (func(), error) {
}) })
} }
return f.wg.Wait, nil return func() error {
f.wg.Wait()
if e := f.fatalErr.Load(); e != nil {
return *e
}
return nil
}, nil
}
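// onFatal records err as the fatal error if none has been recorded yet. Only the call that
// records the error invokes triggerShutdown (when it is non-nil), synchronously on the
// calling goroutine; every later call is a no-op.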
func (f *Interface) onFatal(err error) {
swapped := f.fatalErr.CompareAndSwap(nil, &err)
if !swapped {
return
}
if f.triggerShutdown != nil {
f.triggerShutdown()
}
} }
func (f *Interface) listenOut(i int) { func (f *Interface) listenOut(i int) {
@@ -291,7 +315,7 @@ func (f *Interface) listenOut(i int) {
if err != nil && !f.closed.Load() { if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading inbound packet, closing") f.l.WithError(err).Error("Error while reading inbound packet, closing")
//TODO: Trigger Control to close f.onFatal(err)
} }
f.l.Infof("underlay reader %v is done", i) f.l.Infof("underlay reader %v is done", i)
@@ -310,7 +334,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
if err != nil { if err != nil {
if !f.closed.Load() { if !f.closed.Load() {
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close f.onFatal(err)
} }
break break
} }
@@ -505,5 +529,7 @@ func (f *Interface) Close() error {
} }
// Release the tun device (closing the tun also closes all readers) // Release the tun device (closing the tun also closes all readers)
return f.inside.Close() err = f.inside.Close()
f.wg.Done()
return err
} }

View File

@@ -831,18 +831,18 @@ func (t *tun) Close() error {
} }
err := t.readers[i].Close() err := t.readers[i].Close()
if err != nil { if err != nil {
t.l.WithField("reader", i).WithError(err).Error("Error closing tun reader") t.l.WithField("reader", i).WithError(err).Error("error closing tun reader")
} else { } else {
t.l.WithField("reader", i).Info("Closed tun reader") t.l.WithField("reader", i).Info("closed tun reader")
} }
} }
//this is t.readers[0] too //this is t.readers[0] too
err := t.tunFile.Close() err := t.tunFile.Close()
if err != nil { if err != nil {
t.l.WithField("reader", 0).WithError(err).Error("Error closing tun reader") t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader")
} else { } else {
t.l.WithField("reader", 0).Info("Closed tun reader") t.l.WithField("reader", 0).Info("closed tun reader")
} }
return err return err
} }

View File

@@ -144,10 +144,10 @@ func New(control *nebula.Control) (*Service, error) {
} }
}) })
// Add the nebula wait function to the group // Add the nebula wait function to the group so a fatal reader error
// propagates out through errgroup.Wait().
eg.Go(func() error { eg.Go(func() error {
wait() return wait()
return nil
}) })
return &s, nil return &s, nil