Plug the conntrack cache ticker leak and nebula-service log.Fatal calls (#1669)

Nate Brown
2026-04-21 13:19:54 -05:00
committed by GitHub
parent 2f4532f102
commit 8c50fc3f60
4 changed files with 32 additions and 22 deletions


@@ -50,10 +50,16 @@ func main() {
 		os.Exit(0)
 	}
 
+	l := logrus.New()
+	l.Out = os.Stdout
+
 	if *serviceFlag != "" {
-		doService(configPath, configTest, Build, serviceFlag)
-		os.Exit(1)
+		if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
+			l.WithError(err).Error("Service command failed")
+			os.Exit(1)
+		}
+		return
 	}
 
 	if *configPath == "" {
 		fmt.Println("-config flag must be set")
@@ -61,9 +67,6 @@ func main() {
 		os.Exit(1)
 	}
 
-	l := logrus.New()
-	l.Out = os.Stdout
-
 	c := config.NewC(l)
 	err := c.Load(*configPath)
 	if err != nil {
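
Note: moving the logrus.Logger construction above the service dispatch is what lets main report a doService failure itself and pick the exit code. A minimal sketch of that shape, with run standing in (hypothetically) for doService:

package main

import (
	"errors"
	"os"

	"github.com/sirupsen/logrus"
)

// run is an illustrative stand-in for doService: it hands its failure
// back to the caller instead of terminating the process itself.
func run() error {
	return errors.New("service control failed")
}

func main() {
	l := logrus.New()
	l.Out = os.Stdout

	// main is now the only place that decides the process exit code.
	if err := run(); err != nil {
		l.WithError(err).Error("Service command failed")
		os.Exit(1)
	}
}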


@@ -57,11 +57,11 @@ func fileExists(filename string) bool {
 	return true
 }
 
-func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
+func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
 	if *configPath == "" {
 		ex, err := os.Executable()
 		if err != nil {
-			panic(err)
+			return err
 		}
 		*configPath = filepath.Dir(ex) + "/config.yaml"
 		if !fileExists(*configPath) {
@@ -88,13 +88,13 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
 	// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
 	s, err := service.New(prg, svcConfig)
 	if err != nil {
-		log.Fatal(err)
+		return err
 	}
 
 	errs := make(chan error, 5)
 	logger, err = s.Logger(errs)
 	if err != nil {
-		log.Fatal(err)
+		return err
 	}
 
 	go func() {
@@ -109,18 +109,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
 	switch *serviceFlag {
 	case "run":
-		err = s.Run()
-		if err != nil {
+		if err := s.Run(); err != nil {
 			// Route any errors to the system logger
 			logger.Error(err)
 		}
 	default:
-		err := service.Control(s, *serviceFlag)
-		if err != nil {
+		if err := service.Control(s, *serviceFlag); err != nil {
 			log.Printf("Valid actions: %q\n", service.ControlAction)
-			log.Fatal(err)
+			return err
 		}
-		return
 	}
+
+	return nil
 }
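
Note: log.Fatal is a log call followed by os.Exit(1), so deferred cleanup is skipped and the caller never observes the failure; returning the error instead pushes both reporting and the exit decision up to main. A small self-contained sketch of the difference (fatalStyle and errorStyle are illustrative names, not nebula code):

package main

import (
	"fmt"
	"log"
	"os"
)

// fatalStyle shows the old shape: log.Fatal exits immediately,
// so the deferred cleanup below would never run.
func fatalStyle() {
	defer fmt.Println("cleanup") // never printed
	log.Fatal("boom")
}

// errorStyle shows the replacement shape: the defer runs on the way
// out, and the caller decides what to do with the error.
func errorStyle() error {
	defer fmt.Println("cleanup") // runs before the function returns
	return fmt.Errorf("boom")
}

func main() {
	if err := errorStyle(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}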


@@ -1,6 +1,7 @@
 package firewall
 
 import (
+	"context"
 	"sync/atomic"
 	"time"
@@ -18,7 +19,7 @@ type ConntrackCacheTicker struct {
 	cache ConntrackCache
 }
 
-func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
+func NewConntrackCacheTicker(ctx context.Context, d time.Duration) *ConntrackCacheTicker {
 	if d == 0 {
 		return nil
 	}
@@ -27,16 +28,22 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
 		cache: ConntrackCache{},
 	}
 
-	go c.tick(d)
+	go c.tick(ctx, d)
 
 	return c
 }
 
-func (c *ConntrackCacheTicker) tick(d time.Duration) {
+func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
+	t := time.NewTicker(d)
+	defer t.Stop()
 	for {
-		time.Sleep(d)
-		c.cacheTick.Add(1)
+		select {
+		case <-ctx.Done():
+			return
+		case <-t.C:
+			c.cacheTick.Add(1)
+		}
 	}
 }
 
 // Get checks if the cache ticker has moved to the next version before returning
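
Note: this is the leak named in the commit title. The old tick loop slept in time.Sleep forever, so every ConntrackCacheTicker pinned a goroutine for the life of the process even after its owner shut down; the new loop owns a time.Ticker and exits when the context is cancelled. The same pattern in a runnable sketch (tickUntilCancelled is an illustrative name):

package main

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"
)

// tickUntilCancelled bumps counter every interval until ctx is
// cancelled. Stopping the Ticker on exit means neither the goroutine
// nor the ticker outlives the caller.
func tickUntilCancelled(ctx context.Context, d time.Duration, counter *atomic.Uint64) {
	t := time.NewTicker(d)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			counter.Add(1)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	var n atomic.Uint64

	go tickUntilCancelled(ctx, 10*time.Millisecond, &n)
	time.Sleep(50 * time.Millisecond)
	cancel() // the goroutine observes ctx.Done() and returns

	fmt.Println("ticks observed:", n.Load())
}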


@@ -85,6 +85,7 @@ type Interface struct {
 	conntrackCacheTimeout time.Duration
+	ctx                   context.Context
 
 	writers []udp.Conn
 	readers []io.ReadWriteCloser
 	wg      sync.WaitGroup
@@ -170,6 +171,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	cs := c.pki.getCertState()
 
 	ifce := &Interface{
+		ctx:      ctx,
 		pki:      c.pki,
 		hostMap:  c.HostMap,
 		outside:  c.Outside,
@@ -303,7 +305,7 @@ func (f *Interface) listenOut(i int) {
 		li = f.outside
 	}
 
-	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+	ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
 	plaintext := make([]byte, udp.MTU)
 	h := &header.H{}
@@ -328,7 +330,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
-	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+	conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout)
 
 	for {
 		n, err := reader.Read(packet)
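
Note: storing the lifetime context on Interface is the plumbing that lets each per-reader loop (listenOut, listenIn) hand a cancellable context to NewConntrackCacheTicker, so the tickers die with the interface. A rough sketch of that wiring under assumed names (worker, listen):

package main

import (
	"context"
	"fmt"
	"time"
)

// worker is an illustrative owner type: anything its per-listener
// helpers spawn is tied to the owner's lifetime context.
type worker struct {
	ctx context.Context
}

func (w *worker) listen(id int) {
	t := time.NewTicker(20 * time.Millisecond)
	defer t.Stop()
	for {
		select {
		case <-w.ctx.Done():
			fmt.Printf("listener %d stopped\n", id)
			return
		case <-t.C:
			// per-tick work would happen here
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	w := &worker{ctx: ctx}
	for i := 0; i < 2; i++ {
		go w.listen(i)
	}
	time.Sleep(60 * time.Millisecond)
	cancel()                          // stops every listener at once
	time.Sleep(10 * time.Millisecond) // let them print and exit
}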