diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go
index 8d0eaa1..fce040f 100644
--- a/cmd/nebula-service/main.go
+++ b/cmd/nebula-service/main.go
@@ -65,8 +65,16 @@ func main() {
 	}
 
 	if !*configTest {
-		ctrl.Start()
-		ctrl.ShutdownBlock()
+		wait, err := ctrl.Start()
+		if err != nil {
+			util.LogWithContextIfNeeded("Error while running", err, l)
+			os.Exit(1)
+		}
+
+		go ctrl.ShutdownBlock()
+		wait()
+
+		l.Info("Goodbye")
 	}
 
 	os.Exit(0)
diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go
index 5cf0a02..93f3967 100644
--- a/cmd/nebula/main.go
+++ b/cmd/nebula/main.go
@@ -59,9 +59,17 @@ func main() {
 	}
 
 	if !*configTest {
-		ctrl.Start()
+		wait, err := ctrl.Start()
+		if err != nil {
+			util.LogWithContextIfNeeded("Error while running", err, l)
+			os.Exit(1)
+		}
+
+		go ctrl.ShutdownBlock()
 		notifyReady(l)
-		ctrl.ShutdownBlock()
+		wait()
+
+		l.Info("Goodbye")
 	}
 
 	os.Exit(0)
diff --git a/control.go b/control.go
index 20dd7fe..d694eae 100644
--- a/control.go
+++ b/control.go
@@ -2,9 +2,11 @@ package nebula
 
 import (
 	"context"
+	"errors"
 	"net/netip"
 	"os"
 	"os/signal"
+	"sync"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
@@ -13,6 +15,18 @@ import (
 	"github.com/slackhq/nebula/overlay"
 )
 
+// RunState describes where a Control is in its start/stop lifecycle.
+type RunState int
+
+const (
+	Stopped  RunState = 0 // The control has not been started, or has finished stopping
+	Started  RunState = 1 // The control has been started
+	Stopping RunState = 2 // The control is stopping
+)
+
+// ErrAlreadyStarted is returned by Control.Start when the control is not in the Stopped state.
+var ErrAlreadyStarted = errors.New("nebula is already started")
+
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
 // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
 
@@ -26,6 +40,9 @@ type controlHostLister interface {
 }
 
 type Control struct {
+	stateLock sync.Mutex // guards state
+	state     RunState
+
 	f               *Interface
 	l               *logrus.Logger
 	ctx             context.Context
@@ -48,10 +65,22 @@ type ControlHostInfo struct {
 	CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
 }
 
-// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
-func (c *Control) Start() {
+// Start actually runs nebula, this is a nonblocking call.
+// The returned function can be used to wait for nebula to fully stop.
+// ErrAlreadyStarted is returned if the control is not in the Stopped state.
+func (c *Control) Start() (func(), error) {
+	c.stateLock.Lock()
+	if c.state != Stopped {
+		c.stateLock.Unlock()
+		return nil, ErrAlreadyStarted
+	}
+
 	// Activate the interface
-	c.f.activate()
+	err := c.f.activate()
+	if err != nil {
+		c.stateLock.Unlock()
+		return nil, err
+	}
 
 	// Call all the delayed funcs that waited patiently for the interface to be created.
 	if c.sshStart != nil {
@@ -68,15 +97,36 @@ func (c *Control) Start() {
 	}
 
 	// Start reading packets.
-	c.f.run()
+	c.state = Started
+	c.stateLock.Unlock()
+	return c.f.run()
+}
+
+// State returns the current RunState of the control.
+func (c *Control) State() RunState {
+	c.stateLock.Lock()
+	defer c.stateLock.Unlock()
+	return c.state
 }
 
 func (c *Control) Context() context.Context {
 	return c.ctx
 }
 
-// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
+// Stop signals nebula to close all tunnels and shut down.
+// Use the function returned by Start to wait for nebula to fully stop.
+// Stop does nothing unless the control is currently in the Started state.
 func (c *Control) Stop() {
+	c.stateLock.Lock()
+	if c.state != Started {
+		c.stateLock.Unlock()
+		// We are stopping or stopped already
+		return
+	}
+
+	c.state = Stopping
+	c.stateLock.Unlock()
+
 	// Stop the handshakeManager (and other services), to prevent new tunnels from
 	// being created while we're shutting them all down.
 	c.cancel()
@@ -85,7 +135,9 @@ func (c *Control) Stop() {
 	if err := c.f.Close(); err != nil {
 		c.l.WithError(err).Error("Close interface failed")
 	}
-	c.l.Info("Goodbye")
+	c.stateLock.Lock()
+	c.state = Stopped
+	c.stateLock.Unlock()
 }
 
 // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
diff --git a/interface.go b/interface.go
index 21e198c..5327548 100644
--- a/interface.go
+++ b/interface.go
@@ -6,8 +6,8 @@ import (
 	"fmt"
 	"io"
 	"net/netip"
-	"os"
 	"runtime"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -87,6 +87,7 @@ type Interface struct {
 
 	writers []udp.Conn
 	readers []io.ReadWriteCloser
+	wg      sync.WaitGroup // counts the running listenOut/listenIn routines
 
 	metricHandshakes metrics.Histogram
 	messageMetrics *MessageMetrics
@@ -206,7 +207,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 // activate creates the interface on the host. After the interface is created, any
 // other services that want to bind listeners to its IP may do so successfully. However,
 // the interface isn't going to process anything until run() is called.
-func (f *Interface) activate() {
+func (f *Interface) activate() error {
 	// actually turn on tun dev
 
 	addr, err := f.outside.LocalAddr()
@@ -227,28 +228,37 @@ func (f *Interface) activate() {
 		if i > 0 {
 			reader, err = f.inside.NewMultiQueueReader()
 			if err != nil {
-				f.l.Fatal(err)
+				return err
 			}
 		}
 		f.readers[i] = reader
 	}
 
-	if err := f.inside.Activate(); err != nil {
+	if err = f.inside.Activate(); err != nil {
 		f.inside.Close()
-		f.l.Fatal(err)
+		return err
 	}
+
+	return nil
 }
 
-func (f *Interface) run() {
+// run launches the udp and tun packet routines and returns a function that blocks
+// until every one of them has exited. The returned error is currently always nil.
+func (f *Interface) run() (func(), error) {
 	// Launch n queues to read packets from udp
 	for i := 0; i < f.routines; i++ {
+		// Add before starting the routine so its Done can never run ahead of the Add
+		f.wg.Add(1)
 		go f.listenOut(i)
 	}
 
 	// Launch n queues to read packets from tun dev
 	for i := 0; i < f.routines; i++ {
+		f.wg.Add(1)
 		go f.listenIn(f.readers[i], i)
 	}
+
+	return f.wg.Wait, nil
 }
 
 func (f *Interface) listenOut(i int) {
@@ -271,6 +281,8 @@ func (f *Interface) listenOut(i int) {
 	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
 		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
 	})
+
+	f.wg.Done()
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -286,17 +298,16 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	for {
 		n, err := reader.Read(packet)
 		if err != nil {
-			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
-				return
+			if !f.closed.Load() {
+				f.l.WithError(err).Error("Error while reading outbound packet")
 			}
-
-			f.l.WithError(err).Error("Error while reading outbound packet")
-			// This only seems to happen when something fatal happens to the fd, so exit.
-			os.Exit(2)
+			break
 		}
 
 		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
 	}
+
+	f.wg.Done()
 }
 
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
diff --git a/main.go b/main.go
index b278fa6..d8bd2b2 100644
--- a/main.go
+++ b/main.go
@@ -288,13 +288,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	}
 
 	return &Control{
-		ifce,
-		l,
-		ctx,
-		cancel,
-		sshStart,
-		statsStart,
-		dnsStart,
-		lightHouse.StartUpdateWorker,
+		f:               ifce,
+		l:               l,
+		ctx:             ctx,
+		cancel:          cancel,
+		sshStart:        sshStart,
+		statsStart:      statsStart,
+		dnsStart:        dnsStart,
+		lighthouseStart: lightHouse.StartUpdateWorker,
 	}, nil
 }
diff --git a/service/service.go b/service/service.go
index 4339677..16c244b 100644
--- a/service/service.go
+++ b/service/service.go
@@ -54,7 +54,11 @@ func New(config *config.C) (*Service, error) {
 	if err != nil {
 		return nil, err
 	}
-	control.Start()
+
+	wait, err := control.Start()
+	if err != nil {
+		return nil, err
+	}
 
 	ctx := control.Context()
 	eg, ctx := errgroup.WithContext(ctx)
@@ -151,6 +155,12 @@ func New(config *config.C) (*Service, error) {
 		}
 	})
 
+	// Add the nebula wait function to the group
+	eg.Go(func() error {
+		wait()
+		return nil
+	})
+
 	return &s, nil
 }
 