Remove more os.Exit calls and give a more reliable wait for stop function

This commit is contained in:
Nate Brown
2025-04-02 09:51:59 -05:00
committed by JackDoan
parent f77fe74192
commit 6592a07b51
13 changed files with 161 additions and 81 deletions

View File

@@ -78,8 +78,16 @@ func main() {
} }
if !*configTest { if !*configTest {
ctrl.Start() wait, err := ctrl.Start()
ctrl.ShutdownBlock() if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
} }
os.Exit(0) os.Exit(0)

View File

@@ -72,9 +72,17 @@ func main() {
} }
if !*configTest { if !*configTest {
ctrl.Start() wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l) notifyReady(l)
ctrl.ShutdownBlock() wait()
l.Info("Goodbye")
} }
os.Exit(0) os.Exit(0)

View File

@@ -2,9 +2,11 @@ package nebula
import ( import (
"context" "context"
"errors"
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -13,6 +15,16 @@ import (
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
) )
// RunState describes where a Control is in its start/stop lifecycle.
type RunState int

const (
	// Stopped means the control is not running. It is the zero value, so a
	// freshly constructed Control begins in this state.
	Stopped RunState = iota
	// Started means the control has been started.
	Started
	// Stopping means the control is in the process of shutting down.
	Stopping
)

// ErrAlreadyStarted is returned by Control.Start when the control is not in
// the Stopped state.
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,6 +38,9 @@ type controlHostLister interface {
} }
type Control struct { type Control struct {
stateLock sync.Mutex
state RunState
f *Interface f *Interface
l *logrus.Logger l *logrus.Logger
ctx context.Context ctx context.Context
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
} }
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() // Start actually runs nebula, this is a nonblocking call.
func (c *Control) Start() { // The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Activate the interface // Activate the interface
c.f.activate() err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created. // Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil { if c.sshStart != nil {
@@ -72,15 +98,33 @@ func (c *Control) Start() {
} }
// Start reading packets. // Start reading packets.
c.f.run() c.state = Started
c.stateLock.Unlock()
return c.f.run()
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
} }
func (c *Control) Context() context.Context { func (c *Control) Context() context.Context {
return c.ctx return c.ctx
} }
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete // Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() { func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from // Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down. // being created while we're shutting them all down.
c.cancel() c.cancel()
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil { if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed") c.l.WithError(err).Error("Close interface failed")
} }
c.l.Info("Goodbye") c.state = Stopped
} }
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

View File

@@ -7,6 +7,7 @@ import (
"io" "io"
"net/netip" "net/netip"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -87,6 +88,7 @@ type Interface struct {
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []io.ReadWriteCloser
wg sync.WaitGroup
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
@@ -209,7 +211,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any // activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However, // other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called. // the interface isn't going to process anything until run() is called.
func (f *Interface) activate() { func (f *Interface) activate() error {
// actually turn on tun dev // actually turn on tun dev
addr, err := f.outside.LocalAddr() addr, err := f.outside.LocalAddr()
@@ -237,28 +239,36 @@ func (f *Interface) activate() {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
if err != nil { if err != nil {
f.l.Fatal(err) return err
} }
} }
f.readers[i] = reader f.readers[i] = reader
} }
if err := f.inside.Activate(); err != nil { if err = f.inside.Activate(); err != nil {
f.inside.Close() f.inside.Close()
f.l.Fatal(err) return err
} }
return nil
} }
func (f *Interface) run() { func (f *Interface) run() (func(), error) {
// Launch n queues to read packets from udp // Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
go f.listenOut(i) f.wg.Go(func() {
f.listenOut(i)
})
} }
// Launch n queues to read packets from tun dev // Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i) f.wg.Go(func() {
f.listenIn(f.readers[i], i)
})
} }
return f.wg.Wait, nil
} }
func (f *Interface) listenOut(i int) { func (f *Interface) listenOut(i int) {
@@ -276,9 +286,17 @@ func (f *Interface) listenOut(i int) {
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
}) })
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close
}
f.l.Debugf("underlay reader %v is done", i)
f.wg.Done()
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -292,17 +310,18 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
for { for {
n, err := reader.Read(packet) n, err := reader.Read(packet)
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() { if !f.closed.Load() {
return f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
} }
break
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
} }
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
} }
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -479,15 +498,18 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error { func (f *Interface) Close() error {
f.closed.Store(true) f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers { for _, u := range f.writers {
err := u.Close() err := u.Close()
if err != nil { if err != nil {
f.l.WithError(err).Error("Error while closing udp socket") f.l.WithError(err).Error("Error while closing udp socket")
} }
} }
// Release the tun readers
for i, r := range f.readers { for i, r := range f.readers {
if i == 0 { if i == 0 {
continue // f.readers[0] is f.inside, which we want to save for last continue // f.readers[0] is f.inside, which we want to save for last, since it closes other stuff too
} }
if err := r.Close(); err != nil { if err := r.Close(); err != nil {
f.l.WithError(err).Error("Error while closing tun reader") f.l.WithError(err).Error("Error while closing tun reader")

18
main.go
View File

@@ -288,15 +288,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
return &Control{ return &Control{
ifce, f: ifce,
l, l: l,
ctx, ctx: ctx,
cancel, cancel: cancel,
sshStart, sshStart: sshStart,
statsStart, statsStart: statsStart,
dnsStart, dnsStart: dnsStart,
lightHouse.StartUpdateWorker, lighthouseStart: lightHouse.StartUpdateWorker,
connManager.Start, connectionManagerStart: connManager.Start,
}, nil }, nil
} }

View File

@@ -686,16 +686,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
func (t *tun) Close() error { func (t *tun) Close() error {
if t.routeChan != nil { if t.routeChan != nil {
close(t.routeChan) close(t.routeChan)
} t.routeChan = nil
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
} }
if t.ioctlFd > 0 { if t.ioctlFd > 0 {
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close() err := os.NewFile(t.ioctlFd, "ioctlFd").Close()
if err != nil {
t.l.WithField("error", err).Error("Failed to close ioctl fd")
}
t.ioctlFd = 0 t.ioctlFd = 0
} }
if t.ReadWriteCloser != nil {
err := t.ReadWriteCloser.Close()
if err != nil {
t.l.WithField("error", err).Error("Failed to close tun file")
return err
}
t.ReadWriteCloser = nil
}
return nil return nil
} }

View File

@@ -44,7 +44,10 @@ type Service struct {
} }
func New(control *nebula.Control) (*Service, error) { func New(control *nebula.Control) (*Service, error) {
control.Start() wait, err := control.Start()
if err != nil {
return nil, err
}
ctx := control.Context() ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
} }
}) })
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil return &s, nil
} }

View File

@@ -16,7 +16,7 @@ type EncReader func(
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) ListenOut(r EncReader) error
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
SupportsMultipleReaders() bool SupportsMultipleReaders() bool
@@ -31,8 +31,8 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil return netip.AddrPort{}, nil
} }
func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) ListenOut(_ EncReader) error {
return return nil
} }
func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) SupportsMultipleReaders() bool {
return false return false

View File

@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {} return func() {}
} }
func (u *StdConn) ListenOut(r EncReader) { func (u *StdConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {
@@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) {
n, rua, err := u.ReadFromUDPAddrPort(buffer) n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
u.l.WithError(err).Error("unexpected udp socket receive error") u.l.WithError(err).Error("unexpected udp socket receive error")

View File

@@ -10,11 +10,9 @@ package udp
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@@ -73,25 +71,14 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *GenericConn) ListenOut(r EncReader) { func (u *GenericConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
var lastRecvErr time.Time
for { for {
// Just read one packet at a time // Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer) n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { return err
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
// Dampen unexpected message warns to once per minute
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
lastRecvErr = time.Now()
u.l.WithError(err).Warn("unexpected udp socket receive error")
}
continue
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])

View File

@@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) {
return int(n), true, nil return int(n), true, nil
} }
func (u *StdConn) listenOutSingle(r EncReader) { func (u *StdConn) listenOutSingle(r EncReader) error {
var err error var err error
var n int var n int
var from netip.AddrPort var from netip.AddrPort
@@ -180,15 +180,14 @@ func (u *StdConn) listenOutSingle(r EncReader) {
for { for {
n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
r(from, buffer[:n]) r(from, buffer[:n])
} }
} }
func (u *StdConn) listenOutBatch(r EncReader) { func (u *StdConn) listenOutBatch(r EncReader) error {
var ip netip.Addr var ip netip.Addr
var n int var n int
var operr error var operr error
@@ -205,12 +204,10 @@ func (u *StdConn) listenOutBatch(r EncReader) {
for { for {
err := u.rawConn.Read(reader) err := u.rawConn.Read(reader)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
if operr != nil { if operr != nil {
u.l.WithError(operr).Debug("operr: udp socket is closed, exiting read loop") return operr
return
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@@ -225,14 +222,11 @@ func (u *StdConn) listenOutBatch(r EncReader) {
} }
} }
func (u *StdConn) ListenOut(r EncReader) { func (u *StdConn) ListenOut(r EncReader) error {
if u.batch == 1 { if u.batch == 1 {
//save some ram by not calling PrepareRawMessages for fields we won't use return u.listenOutSingle(r)
//we could also make this path more common by calling recvmmsg with msgs[:1],
//but that's still the recvmmsg syscall, which would be a change
u.listenOutSingle(r)
} else { } else {
u.listenOutBatch(r) return u.listenOutBatch(r)
} }
} }
@@ -290,7 +284,7 @@ func (u *StdConn) ReloadConfig(c *config.C) {
} }
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
var vallen uint32 = 4 * unix.SK_MEMINFO_VARS const vallen uint32 = 4 * unix.SK_MEMINFO_VARS
if u.rawConn == nil { if u.rawConn == nil {
return fmt.Errorf("no UDP connection") return fmt.Errorf("no UDP connection")

View File

@@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
return nil return nil
} }
func (u *RIOConn) ListenOut(r EncReader) { func (u *RIOConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
var lastRecvErr time.Time var lastRecvErr time.Time
@@ -151,8 +151,7 @@ func (u *RIOConn) ListenOut(r EncReader) {
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
// Dampen unexpected message warns to once per minute // Dampen unexpected message warns to once per minute
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {

View File

@@ -6,6 +6,7 @@ package udp
import ( import (
"io" "io"
"net/netip" "net/netip"
"os"
"sync/atomic" "sync/atomic"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil return nil
} }
func (u *TesterConn) ListenOut(r EncReader) { func (u *TesterConn) ListenOut(r EncReader) error {
for { for {
p, ok := <-u.RxPackets p, ok := <-u.RxPackets
if !ok { if !ok {
return return os.ErrClosed
} }
r(p.From, p.Data) r(p.From, p.Data)
} }