From 0b02d982b256dffc9c215306a2e550d8a1bd16ab Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 21 Jan 2026 12:42:34 -0500 Subject: [PATCH 01/44] v1.10.2 (#1584) Update CHANGELOG for Nebula v1.10.2 --- CHANGELOG.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 104b52e3..330a7f78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.10.2] - 2026-01-21 + +### Fixed + +- Fix panic when using `use_system_route_table` that was introduced in v1.10.1. (#1580) + +### Changed + +- Fix some typos in comments. (#1582) +- Dependency updates. (#1581) + ## [1.10.1] - 2026-01-16 See the [v1.10.1](https://github.com/slackhq/nebula/milestone/26?closed=1) milestone for a complete list of changes. @@ -764,7 +775,8 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.1...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.2...HEAD +[1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2 [1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1 [1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0 [1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7 From 02d8bcac68ccf01f992c89bd89c8d1c7b9670945 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Tue, 27 Jan 2026 23:44:43 -0600 Subject: [PATCH 02/44] Remove lighthouse goroutine leaks in lighthouse_test.go (#1589) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Using + Claude, I was able to run nebula's unit tests and e2e tests with the leak detector enabled. Added a TestMain that queries pprof to see if there are any reported goroutine leaks. I'd love to get some form of this in CI whenever go 1.26 comes out, though I'd also like to prove this is properly useful past the just five detections it got here.
TestMain ```go package nebula import ( "fmt" "os" "runtime/pprof" "strings" "testing" ) // TestMain runs after all tests and checks for goroutine leaks func TestMain(m *testing.M) { // Run all tests exitCode := m.Run() // Check for goroutine leaks after all tests complete prof := pprof.Lookup("goroutineleak") if prof != nil { var sb strings.Builder if err := prof.WriteTo(&sb, 2); err != nil { fmt.Fprintf(os.Stderr, "Failed to write goroutineleak profile: %v\n", err) os.Exit(1) } content := sb.String() leakedCount := strings.Count(content, "(leaked)") if leakedCount > 0 { fmt.Fprintf(os.Stderr, "\n=== GOROUTINE LEAK DETECTED ===\n") fmt.Fprintf(os.Stderr, "Found %d leaked goroutine(s) in package nebula\n\n", leakedCount) goros := strings.Split(content, "\n\n") for _, goro := range goros { if strings.Contains(goro, "(leaked)") { fmt.Fprintln(os.Stderr, goro) fmt.Fprintln(os.Stderr) } } os.Exit(1) } else { fmt.Println("✓ No goroutine leaks detected in package nebula") } } os.Exit(exitCode) } ```
Also had to install go1.26rc2 and update the makefile to use that go binary + set ex: ```makefile test-goroutineleak: GOEXPERIMENT=goroutineleakprofile go1.26rc2 test -v ./... ``` --- lighthouse_test.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/lighthouse_test.go b/lighthouse_test.go index fea1d1ed..c57c44ec 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -1,7 +1,6 @@ package nebula import ( - "context" "encoding/binary" "fmt" "net/netip" @@ -42,14 +41,14 @@ func Test_lhStaticMapping(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}} c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + _, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}} c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + _, err = NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } @@ -71,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) { } c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -99,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { } c := config.NewC(l) - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(b.Context(), l, c, cs, nil, nil) require.NoError(b, err) hAddr := netip.MustParseAddrPort("4.5.6.7:12345") @@ -202,7 +201,7 @@ func 
TestLighthouse_Memory(t *testing.T) { myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) lh.ifce = &mockEncWriter{} require.NoError(t, err) lhh := lh.NewRequestHandler() @@ -288,7 +287,7 @@ func TestLighthouse_reload(t *testing.T) { myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) nc := map[string]any{ @@ -523,7 +522,7 @@ func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) { myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -589,7 +588,7 @@ func TestLighthouse_DeletesWork(t *testing.T) { myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} From 42bee7cf179f913e00d6964eeffd396784f2a17b Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Wed, 28 Jan 2026 10:03:36 -0600 Subject: [PATCH 03/44] Report if Nebula start fails because of tun device name (#1588) * specifically report if nebula start fails because of tun device name * close all routines when closing the tun --- interface.go | 8 ++++++++ overlay/tun.go | 9 +++++++++ overlay/tun_freebsd.go | 2 +- overlay/tun_linux.go | 9 +++++++-- overlay/tun_windows.go | 5 ++++- 5 files changed, 29 insertions(+), 4 deletions(-) diff --git a/interface.go b/interface.go index f69ed062..61b1f228 100644 --- a/interface.go +++ b/interface.go @@ -490,6 +490,14 @@ func (f *Interface) Close() error 
{ f.l.WithError(err).Error("Error while closing udp socket") } } + for i, r := range f.readers { + if i == 0 { + continue // f.readers[0] is f.inside, which we want to save for last + } + if err := r.Close(); err != nil { + f.l.WithError(err).Error("Error while closing tun reader") + } + } // Release the tun device return f.inside.Close() diff --git a/overlay/tun.go b/overlay/tun.go index 3a61d186..e0bf69f6 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -12,6 +12,15 @@ import ( const DefaultMTU = 1300 +type NameError struct { + Name string + Underlying error +} + +func (e *NameError) Error() string { + return fmt.Sprintf("could not set tun device name: %s because %s", e.Name, e.Underlying) +} + // TODO: We may be able to remove routines type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 8d292263..2f65b3a4 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -266,7 +266,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } // Set the device name - ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) + _ = ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) } t := &tun{ diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index ea666f86..7e4aa418 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -112,9 +112,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } - copy(req.Name[:], c.GetString("tun.dev", "")) + nameStr := c.GetString("tun.dev", "") + copy(req.Name[:], nameStr) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { - return nil, err + return nil, &NameError{ + Name: nameStr, + Underlying: err, + } } name := strings.Trim(string(req.Name[:]), "\x00") @@ -713,6 +717,7 @@ func (t *tun) Close() error { if t.ioctlFd 
> 0 { _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() + t.ioctlFd = 0 } return nil diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index b4d78b66..223eabee 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -74,7 +74,10 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( l.WithError(err).Debug("Failed to create wintun device, retrying") tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { - return nil, fmt.Errorf("create TUN device failed: %w", err) + return nil, &NameError{ + Name: deviceName, + Underlying: fmt.Errorf("create TUN device failed: %w", err), + } } } t.tun = tunDevice.(*wintun.NativeTun) From f573e8a26695278f9d71587390fbfe0d0933aa21 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Fri, 6 Feb 2026 13:26:51 -0600 Subject: [PATCH 04/44] Merge commit from fork Newly signed P256 based certificates will have their signature clamped to the low-s form. Update CHANGELOG.md --- CHANGELOG.md | 16 +++++- cert/ca_pool.go | 18 ++++++ cert/ca_pool_test.go | 47 ++++++++++++++-- cert/cert.go | 33 +++++++++++ cert/p256/p256.go | 122 +++++++++++++++++++++++++++++++++++++++++ cert/p256/p256_test.go | 28 ++++++++++ cert/sign.go | 9 +++ cert/sign_test.go | 46 ++++++++++++++++ go.mod | 1 + go.sum | 2 + 10 files changed, 317 insertions(+), 5 deletions(-) create mode 100644 cert/p256/p256.go create mode 100644 cert/p256/p256_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 330a7f78..2ef7551f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.10.3] - 2026-02-06 + +### Security + +- Fix an issue where blocklist bypass is possible when using curve P256 since the signature can have 2 valid representations. + Both fingerprint representations will be tested against the blocklist. 
+ Any newly issued P256 based certificates will have their signature clamped to the low-s form. + Nebula will assert the low-s signature form when validating certificates in a future version. [GHSA-69x3-g4r3-p962](https://github.com/slackhq/nebula/security/advisories/GHSA-69x3-g4r3-p962) + +### Changed + +- Improve error reporting if nebula fails to start due to a tun device naming issue. (#1588) + ## [1.10.2] - 2026-01-21 ### Fixed @@ -775,7 +788,8 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.2...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.3...HEAD +[1.10.3]: https://github.com/slackhq/nebula/releases/tag/v1.10.3 [1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2 [1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1 [1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0 diff --git a/cert/ca_pool.go b/cert/ca_pool.go index 2bf480f2..e9903e1f 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -141,10 +141,23 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti return nil, err } + // Pre nebula v1.10.3 could generate signatures in either high or low s form and validation + // of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form + // but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we + // need to test both forms until such a time comes that we enforce low-s form on signature validation. 
+ fp2, err := CalculateAlternateFingerprint(c) + if err != nil { + return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err) + } + if fp2 != "" && ncp.IsBlocklisted(fp2) { + return nil, ErrBlockListed + } + cc := CachedCertificate{ Certificate: c, InvertedGroups: make(map[string]struct{}), Fingerprint: fp, + fingerprint2: fp2, signerFingerprint: signer.Fingerprint, } @@ -158,6 +171,11 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti // VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and // is a cheaper operation to perform as a result. func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { + // Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s + if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) { + return ErrBlockListed + } + _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint) return err } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index b0fdd5fb..e872c7d4 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert/p256" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -170,6 +171,15 @@ func TestCertificateV1_VerifyP256(t *testing.T) { _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") + // Create a copy of the cert and swap to the alternate form for the signature + nc := c.Copy() + b, err := p256.Swap(c.Signature()) + require.NoError(t, err) + require.NoError(t, nc.(*certificateV1).setSignature(b)) + + _, err = caPool.VerifyCertificate(time.Now(), nc) + require.EqualError(t, err, "certificate is in the block list") + caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) @@ -187,7 +197,7 @@ func 
TestCertificateV1_VerifyP256(t *testing.T) { require.NoError(t, err) caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) + b, err = caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) @@ -196,7 +206,17 @@ func TestCertificateV1_VerifyP256(t *testing.T) { }) c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - _, err = caPool.VerifyCertificate(time.Now(), c) + cc, err := caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Reset the blocklist and block the alternate form fingerprint + caPool.ResetCertBlocklist() + caPool.BlocklistFingerprint(cc.fingerprint2) + err = caPool.VerifyCachedCertificate(time.Now(), cc) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + err = caPool.VerifyCachedCertificate(time.Now(), cc) require.NoError(t, err) } @@ -394,6 +414,15 @@ func TestCertificateV2_VerifyP256(t *testing.T) { _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") + // Create a copy of the cert and swap to the alternate form for the signature + nc := c.Copy() + b, err := p256.Swap(c.Signature()) + require.NoError(t, err) + require.NoError(t, nc.(*certificateV2).setSignature(b)) + + _, err = caPool.VerifyCertificate(time.Now(), nc) + require.EqualError(t, err, "certificate is in the block list") + caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) @@ -411,7 +440,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { require.NoError(t, err) caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) + b, err = caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) @@ -420,7 +449,17 @@ func TestCertificateV2_VerifyP256(t *testing.T) { }) c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) 
- _, err = caPool.VerifyCertificate(time.Now(), c) + cc, err := caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Reset the blocklist and block the alternate form fingerprint + caPool.ResetCertBlocklist() + caPool.BlocklistFingerprint(cc.fingerprint2) + err = caPool.VerifyCachedCertificate(time.Now(), cc) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + err = caPool.VerifyCachedCertificate(time.Now(), cc) require.NoError(t, err) } diff --git a/cert/cert.go b/cert/cert.go index 9d40e625..01d775e5 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -4,6 +4,8 @@ import ( "fmt" "net/netip" "time" + + "github.com/slackhq/nebula/cert/p256" ) type Version uint8 @@ -110,6 +112,9 @@ type CachedCertificate struct { InvertedGroups map[string]struct{} Fingerprint string signerFingerprint string + + // A place to store a 2nd fingerprint if the certificate could have one, such as with P256 + fingerprint2 string } func (cc *CachedCertificate) String() string { @@ -152,3 +157,31 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific return c, nil } + +// CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates +// CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step. 
+func CalculateAlternateFingerprint(c Certificate) (string, error) { + if c.Curve() != Curve_P256 { + return "", nil + } + + nc := c.Copy() + b, err := p256.Swap(nc.Signature()) + if err != nil { + return "", err + } + + switch v := nc.(type) { + case *certificateV1: + err = v.setSignature(b) + case *certificateV2: + err = v.setSignature(b) + default: + return "", ErrUnknownVersion + } + + if err != nil { + return "", err + } + return nc.Fingerprint() +} diff --git a/cert/p256/p256.go b/cert/p256/p256.go new file mode 100644 index 00000000..be0a2381 --- /dev/null +++ b/cert/p256/p256.go @@ -0,0 +1,122 @@ +package p256 + +import ( + "crypto/elliptic" + "errors" + "math/big" + + "filippo.io/bigmod" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +var halfN = new(big.Int).Rsh(elliptic.P256().Params().N, 1) +var nMod *bigmod.Modulus + +func init() { + n, err := bigmod.NewModulus(elliptic.P256().Params().N.Bytes()) + if err != nil { + panic(err) + } + nMod = n +} + +func IsNormalized(sig []byte) (bool, error) { + r, s, err := parseSignature(sig) + if err != nil { + return false, err + } + return checkLowS(r, s), nil +} + +func checkLowS(_, s []byte) bool { + bigS := new(big.Int).SetBytes(s) + // Check if S <= (N/2), because we want to include the midpoint in the set of low-s + return bigS.Cmp(halfN) <= 0 +} + +func swap(r, s []byte) ([]byte, []byte, error) { + var err error + bigS, err := bigmod.NewNat().SetBytes(s, nMod) + if err != nil { + return nil, nil, err + } + sNormalized := nMod.Nat().Sub(bigS, nMod) + + return r, sNormalized.Bytes(nMod), nil +} + +func Normalize(sig []byte) ([]byte, error) { + r, s, err := parseSignature(sig) + if err != nil { + return nil, err + } + + if checkLowS(r, s) { + return sig, nil + } + + newR, newS, err := swap(r, s) + if err != nil { + return nil, err + } + + return encodeSignature(newR, newS) +} + +// Swap will change sig between its current form to the opposite high or low form. 
+func Swap(sig []byte) ([]byte, error) { + r, s, err := parseSignature(sig) + if err != nil { + return nil, err + } + + newR, newS, err := swap(r, s) + if err != nil { + return nil, err + } + + return encodeSignature(newR, newS) +} + +// parseSignature taken exactly from crypto/ecdsa/ecdsa.go +func parseSignature(sig []byte) (r, s []byte, err error) { + var inner cryptobyte.String + input := cryptobyte.String(sig) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(&r) || + !inner.ReadASN1Integer(&s) || + !inner.Empty() { + return nil, nil, errors.New("invalid ASN.1") + } + return r, s, nil +} + +func encodeSignature(r, s []byte) ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + addASN1IntBytes(b, r) + addASN1IntBytes(b, s) + }) + return b.Bytes() +} + +// addASN1IntBytes encodes in ASN.1 a positive integer represented as +// a big-endian byte slice with zero or more leading zeroes. +func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { + for len(bytes) > 0 && bytes[0] == 0 { + bytes = bytes[1:] + } + if len(bytes) == 0 { + b.SetError(errors.New("invalid integer")) + return + } + b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) { + if bytes[0]&0x80 != 0 { + c.AddUint8(0) + } + c.AddBytes(bytes) + }) +} diff --git a/cert/p256/p256_test.go b/cert/p256/p256_test.go new file mode 100644 index 00000000..486a7242 --- /dev/null +++ b/cert/p256/p256_test.go @@ -0,0 +1,28 @@ +package p256 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFlipping(t *testing.T) { + priv, err1 := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err1) + + out, err := ecdsa.SignASN1(rand.Reader, priv, []byte("big chungus")) + require.NoError(t, err) + + r, s, err := parseSignature(out) + require.NoError(t, err) + + r, s1, err := swap(r, s) + require.NoError(t, err) + r, s2, err := swap(r, 
s1) + require.NoError(t, err) + require.Equal(t, s, s2) + require.NotEqual(t, s, s1) +} diff --git a/cert/sign.go b/cert/sign.go index 3eb08592..fbfffe4e 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -9,6 +9,8 @@ import ( "fmt" "net/netip" "time" + + "github.com/slackhq/nebula/cert/p256" ) // TBSCertificate represents a certificate intended to be signed. @@ -126,6 +128,13 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb return nil, err } + if curve == Curve_P256 { + sig, err = p256.Normalize(sig) + if err != nil { + return nil, err + } + } + err = c.setSignature(sig) if err != nil { return nil, err diff --git a/cert/sign_test.go b/cert/sign_test.go index e6f43cdf..bf4c9c0d 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert/p256" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -89,3 +90,48 @@ func TestCertificateV1_SignP256(t *testing.T) { require.NoError(t, err) assert.NotNil(t, uc) } + +func TestCertificate_SignP256_AlwaysNormalized(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: true, + Curve: Curve_P256, + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := 
priv.D.FillBytes(make([]byte, 32)) + + for i := 0; i < 1000; i++ { + if i&1 == 1 { + tbs.Version = Version1 + } else { + tbs.Version = Version2 + } + c, err := tbs.Sign(nil, Curve_P256, rawPriv) + require.NoError(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + normie, err := p256.IsNormalized(c.Signature()) + require.NoError(t, err) + assert.True(t, normie) + } +} diff --git a/go.mod b/go.mod index 1c564d03..f302f928 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25 require ( dario.cat/mergo v1.0.2 + filippo.io/bigmod v0.1.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 diff --git a/go.sum b/go.sum index c4613e01..f4b1074c 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8= +filippo.io/bigmod v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= From 353ad1f27193b92d3bb7d6696cac5239fa1728e5 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Fri, 13 Feb 2026 11:10:40 -0600 Subject: [PATCH 05/44] firewall: icmp no longer requires a port spec (#1609) --- examples/config.yml | 2 +- firewall.go | 138 +++++++++++++++++++++++++------------------- firewall_test.go | 20 +++++-- 3 files changed, 95 insertions(+), 65 deletions(-) diff --git a/examples/config.yml 
b/examples/config.yml index f81baab6..1f9dc2a4 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -382,8 +382,8 @@ firewall: # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr) # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). - # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # proto: `any`, `tcp`, `udp`, or `icmp` + # a port specification is ignored if proto is `icmp` # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass diff --git a/firewall.go b/firewall.go index 45dc0691..72119e0e 100644 --- a/firewall.go +++ b/firewall.go @@ -249,20 +249,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew // AddRule properly creates the in memory rule structure for a firewall table. func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { - // We need this rule string because we generate a hash. Removing this will break firewall reload. 
- ruleString := fmt.Sprintf( - "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, - ) - f.rules += ruleString + "\n" - - direction := "incoming" - if !incoming { - direction = "outgoing" - } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). - Info("Firewall rule added") - var ( ft *FirewallTable fp firewallPort @@ -280,6 +266,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort case firewall.ProtoUDP: fp = ft.UDP case firewall.ProtoICMP, firewall.ProtoICMPv6: + //ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided + if startPort != firewall.PortAny { + f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule") + } + startPort = firewall.PortAny + endPort = firewall.PortAny fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -287,6 +279,20 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort return fmt.Errorf("unknown protocol %v", proto) } + // We need this rule string because we generate a hash. Removing this will break firewall reload. + ruleString := fmt.Sprintf( + "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", + incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, + ) + f.rules += ruleString + "\n" + + direction := "incoming" + if !incoming { + direction = "outgoing" + } + f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). 
+ Info("Firewall rule added") + return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) } @@ -349,24 +355,31 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw sPort = r.Port } - startPort, endPort, err := parsePort(sPort) - if err != nil { - return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) - } - var proto uint8 + var startPort, endPort int32 switch r.Proto { case "any": proto = firewall.ProtoAny + startPort, endPort, err = parsePort(sPort) case "tcp": proto = firewall.ProtoTCP + startPort, endPort, err = parsePort(sPort) case "udp": proto = firewall.ProtoUDP + startPort, endPort, err = parsePort(sPort) case "icmp": proto = firewall.ProtoICMP + startPort = firewall.PortAny + endPort = firewall.PortAny + if sPort != "" { + l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule") + } default: return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } + if err != nil { + return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) + } if r.Cidr != "" && r.Cidr != "any" { _, err = netip.ParsePrefix(r.Cidr) @@ -660,6 +673,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer return false } + // this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match + if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 { + // port numbers are re-used for connection tracking of ICMP, + // but we don't want to actually filter on them. + return fp[firewall.PortAny].match(p, c, caPool) + } + var port int32 if p.Fragment { @@ -1018,54 +1038,56 @@ func (r *rule) sanity() error { } } + if r.Code != "" { + return fmt.Errorf("code specified as [%s]. 
Support for 'code' will be dropped in a future release, as it has never been functional", r.Code) + } + //todo alert on cidr-any return nil } -func parsePort(s string) (startPort, endPort int32, err error) { +func parsePort(s string) (int32, int32, error) { + var err error + const notAPort int32 = -2 if s == "any" { - startPort = firewall.PortAny - endPort = firewall.PortAny - - } else if s == "fragment" { - startPort = firewall.PortFragment - endPort = firewall.PortFragment - - } else if strings.Contains(s, `-`) { - sPorts := strings.SplitN(s, `-`, 2) - sPorts[0] = strings.Trim(sPorts[0], " ") - sPorts[1] = strings.Trim(sPorts[1], " ") - - if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { - return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) - } - - rStartPort, err := strconv.Atoi(sPorts[0]) - if err != nil { - return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) - } - - rEndPort, err := strconv.Atoi(sPorts[1]) - if err != nil { - return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) - } - - startPort = int32(rStartPort) - endPort = int32(rEndPort) - - if startPort == firewall.PortAny { - endPort = firewall.PortAny - } - - } else { + return firewall.PortAny, firewall.PortAny, nil + } + if s == "fragment" { + return firewall.PortFragment, firewall.PortFragment, nil + } + if !strings.Contains(s, `-`) { rPort, err := strconv.Atoi(s) if err != nil { - return 0, 0, fmt.Errorf("was not a number; `%s`", s) + return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s) } - startPort = int32(rPort) - endPort = startPort + return int32(rPort), int32(rPort), nil } - return + sPorts := strings.SplitN(s, `-`, 2) + for i := range sPorts { + sPorts[i] = strings.Trim(sPorts[i], " ") + } + if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { + return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) + } + + rStartPort, err := 
strconv.Atoi(sPorts[0]) + if err != nil { + return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) + } + + rEndPort, err := strconv.Atoi(sPorts[1]) + if err != nil { + return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) + } + + startPort := int32(rStartPort) + endPort := int32(rEndPort) + + if startPort == firewall.PortAny { + endPort = firewall.PortAny + } + + return startPort, endPort, nil } diff --git a/firewall_test.go b/firewall_test.go index 1df62a81..934a90a4 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -87,9 +87,10 @@ func TestFirewall_AddRule(t *testing.T) { fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) - assert.Nil(t, fw.InRules.ICMP[1].Any.Any) - assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) - assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") + //no matter what port is given for icmp, it should end up as "any" + assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any) + assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups) + assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) @@ -919,11 +920,11 @@ func TestNewFirewallFromConfig(t *testing.T) { // Test code/port error conf = config.NewC(l) - conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") - conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}} + 
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") @@ -973,7 +974,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) + + // Test adding icmp rule no port + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding any rule conf = config.NewC(l) From e8bb874e14e2c8dbbc18b678a54c731eefbdb6b4 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Fri, 13 Feb 2026 13:55:19 -0500 Subject: [PATCH 06/44] smoke-extra: try AMD-V workaround (#1610) * smoke-extra: try AMD-V workaround - https://github.com/slackhq/nebula/actions/runs/21995850645/job/63555492676?pr=1602 - https://github.com/actions/runner-images/issues/13202 - https://github.com/cri-o/packaging/pull/306/changes --- .github/workflows/smoke-extra.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index 24f899ab..cdd6ea9d 100644 
--- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -30,6 +30,9 @@ jobs: - name: add hashicorp source run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list + - name: workaround AMD-V issue # https://github.com/cri-o/packaging/pull/306 + run: sudo rmmod kvm_amd + - name: install vagrant run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox From 422fc2ad1e6e5e47d13b846842f789c515d36493 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Tue, 17 Feb 2026 11:42:14 -0500 Subject: [PATCH 07/44] go fix (#1608) --- boring.go | 1 - control_tester.go | 1 - firewall.go | 6 ++---- handshake_manager.go | 2 +- hostmap_tester.go | 1 - inside_bsd.go | 1 - inside_generic.go | 1 - lighthouse.go | 13 +++---------- main.go | 6 +----- notboring.go | 1 - outside_test.go | 2 +- relay_manager.go | 2 +- remote_list.go | 7 +------ ssh.go | 5 ++--- timeout_test.go | 4 ++-- 15 files changed, 14 insertions(+), 39 deletions(-) diff --git a/boring.go b/boring.go index 9cd9d37f..abe403fc 100644 --- a/boring.go +++ b/boring.go @@ -1,5 +1,4 @@ //go:build boringcrypto -// +build boringcrypto package nebula diff --git a/control_tester.go b/control_tester.go index 7403a745..f927140b 100644 --- a/control_tester.go +++ b/control_tester.go @@ -1,5 +1,4 @@ //go:build e2e_testing -// +build e2e_testing package nebula diff --git a/firewall.go b/firewall.go index 72119e0e..da819221 100644 --- a/firewall.go +++ b/firewall.go @@ -824,10 +824,8 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { return true } - for _, group := range groups { - if group == "any" { - return true - } + if slices.Contains(groups, "any") { + return true } if host == "any" { diff --git 
a/handshake_manager.go b/handshake_manager.go index 8b1ce839..25a59b6d 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -590,7 +590,7 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { hm.Lock() defer hm.Unlock() - for i := 0; i < 32; i++ { + for range 32 { index, err := generateIndex(hm.l) if err != nil { return err diff --git a/hostmap_tester.go b/hostmap_tester.go index fe40c533..a6ac6d44 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -1,5 +1,4 @@ //go:build e2e_testing -// +build e2e_testing package nebula diff --git a/inside_bsd.go b/inside_bsd.go index c9c7730d..dc847878 100644 --- a/inside_bsd.go +++ b/inside_bsd.go @@ -1,5 +1,4 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package nebula diff --git a/inside_generic.go b/inside_generic.go index 0bb2345a..bdcc1a6a 100644 --- a/inside_generic.go +++ b/inside_generic.go @@ -1,5 +1,4 @@ //go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd -// +build !darwin,!dragonfly,!freebsd,!netbsd,!openbsd package nebula diff --git a/lighthouse.go b/lighthouse.go index 1510b942..36eb9aa0 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -713,21 +713,14 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { l := lh.GetLighthouses() - for i := range l { - if l[i] == vpnAddr { - return true - } - } - return false + return slices.Contains(l, vpnAddr) } func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool { l := lh.GetLighthouses() for i := range vpnAddrs { - for j := range l { - if l[j] == vpnAddrs[i] { - return true - } + if slices.Contains(l, vpnAddrs[i]) { + return true } } return false diff --git a/main.go b/main.go index 17aaa548..74979417 100644 --- a/main.go +++ b/main.go @@ -105,11 +105,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // 
deprecated and undocumented tunQueues := c.GetInt("tun.routines", 1) udpQueues := c.GetInt("listen.routines", 1) - if tunQueues > udpQueues { - routines = tunQueues - } else { - routines = udpQueues - } + routines = max(tunQueues, udpQueues) if routines != 1 { l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead") } diff --git a/notboring.go b/notboring.go index c86b0bc3..f138a0a6 100644 --- a/notboring.go +++ b/notboring.go @@ -1,5 +1,4 @@ //go:build !boringcrypto -// +build !boringcrypto package nebula diff --git a/outside_test.go b/outside_test.go index 38dbef62..2a750d40 100644 --- a/outside_test.go +++ b/outside_test.go @@ -574,7 +574,7 @@ func BenchmarkParseV6(b *testing.B) { } evilBytes := buffer.Bytes() - for i := 0; i < 200; i++ { + for range 200 { evilBytes = append(evilBytes, hopHeader...) } evilBytes = append(evilBytes, lastHopHeader...) diff --git a/relay_manager.go b/relay_manager.go index 5dd355ca..91640f24 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -55,7 +55,7 @@ func (rm *relayManager) setAmRelay(v bool) { func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() - for i := 0; i < 32; i++ { + for range 32 { index, err := generateIndex(l) if err != nil { return 0, err diff --git a/remote_list.go b/remote_list.go index 1304fd51..8338d517 100644 --- a/remote_list.go +++ b/remote_list.go @@ -404,12 +404,7 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { - for _, v := range r.badRemotes { - if v == remote { - return true - } - } - return false + return slices.Contains(r.badRemotes, remote) } // unlockedSetLearnedV4 assumes you have the write lock and sets the 
current learned address for this owner and marks the diff --git a/ssh.go b/ssh.go index 9a26c290..0a9adb51 100644 --- a/ssh.go +++ b/ssh.go @@ -6,6 +6,7 @@ import ( "errors" "flag" "fmt" + "maps" "net" "net/netip" "os" @@ -831,9 +832,7 @@ func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) er relays := map[uint32]*HostInfo{} ifce.hostMap.Lock() - for k, v := range ifce.hostMap.Relays { - relays[k] = v - } + maps.Copy(relays, ifce.hostMap.Relays) ifce.hostMap.Unlock() type RelayFor struct { diff --git a/timeout_test.go b/timeout_test.go index db36fec7..ffeecc55 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -134,7 +134,7 @@ func TestTimerWheel_Purge(t *testing.T) { assert.True(t, tw.lastTick.After(lastTick)) // Make sure we get all 4 packets back - for i := 0; i < 4; i++ { + for i := range 4 { p, has := tw.Purge() assert.True(t, has) assert.Equal(t, fps[i], p) @@ -149,7 +149,7 @@ func TestTimerWheel_Purge(t *testing.T) { // Make sure we cached the free'd items assert.Equal(t, 4, tw.itemsCached) ci := tw.itemCache - for i := 0; i < 4; i++ { + for range 4 { assert.NotNil(t, ci) ci = ci.Next } From 51308b845b6fc1bf19dd522db3ec0f22011e7617 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Wed, 18 Feb 2026 23:19:37 -0600 Subject: [PATCH 08/44] connection-track ICMP traffic (#1602) * connection-track ICMP and ICMPv6 traffic * icmpv6 only has identifier on echo --- .github/workflows/smoke/smoke.sh | 19 ++-- firewall.go | 2 +- firewall/packet.go | 7 +- firewall_test.go | 144 +++++++++++++++++++++++++++++++ outside.go | 58 +++++++++---- outside_test.go | 25 ++++-- 6 files changed, 216 insertions(+), 39 deletions(-) diff --git a/.github/workflows/smoke/smoke.sh b/.github/workflows/smoke/smoke.sh index 6d04027a..66164921 100755 --- a/.github/workflows/smoke/smoke.sh +++ b/.github/workflows/smoke/smoke.sh @@ -37,17 +37,18 @@ docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN - sleep 1 # grab tcpdump pcaps for debugging 
-docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & +docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & -docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & +docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & -docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & +docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap & -docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & +docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap & docker exec host2 ncat -nklv 0.0.0.0 2000 & docker exec host3 ncat -nklv 0.0.0.0 2000 & +docker exec host4 ncat -nkluv 0.0.0.0 4000 & docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & @@ -119,11 +120,11 @@ echo echo " *** Testing conntrack" echo set -x -# host2 can ping host3 now that host3 pinged it first -docker exec host2 ping -c1 192.168.100.3 -# host4 can ping host2 once conntrack established -docker exec host2 ping -c1 192.168.100.4 -docker exec host4 ping -c1 192.168.100.2 + +# host2 speaking to host4 on UDP 4000 should allow it to reply, when firewall rules would normally not permit this +docker exec host2 sh -c "/usr/bin/echo host2 | ncat -nuv 192.168.100.4 4000" 
+docker exec host2 ncat -e '/usr/bin/echo helloagainfromhost2' -nkluv 0.0.0.0 4000 & +docker exec host4 sh -c "/usr/bin/echo host4 | ncat -nuv 192.168.100.2 4000" docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' diff --git a/firewall.go b/firewall.go index da819221..2d67acbb 100644 --- a/firewall.go +++ b/firewall.go @@ -480,7 +480,7 @@ func (f *Firewall) metrics(incoming bool) firewallMetrics { } } -// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new +// Destroy cleans up any known cyclical references so the object can be freed by GC. This should be called if a new // firewall object is created func (f *Firewall) Destroy() { //TODO: clean references if/when needed diff --git a/firewall/packet.go b/firewall/packet.go index 40c7fc5d..2cbfb5ea 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -22,7 +22,10 @@ const ( type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr - LocalPort uint16 + // LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP. + LocalPort uint16 + // RemotePort is the source port for incoming traffic, or the destination port for outgoing. + // For ICMP, it's the "identifier". 
This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier RemotePort uint16 Protocol uint8 Fragment bool @@ -46,6 +49,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) { proto = "tcp" case ProtoICMP: proto = "icmp" + case ProtoICMPv6: + proto = "icmpv6" case ProtoUDP: proto = "udp" default: diff --git a/firewall_test.go b/firewall_test.go index 934a90a4..a2133760 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -735,6 +735,150 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) } +func TestFirewall_ICMPPortBehavior(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) + + network := netip.MustParsePrefix("1.2.3.4/24") + + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + }, + InvertedGroups: map[string]struct{}{"default-group": {}}, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + vpnAddrs: []netip.Addr{network.Addr()}, + } + h.buildNetworks(myVpnNetworksTable, c.Certificate) + + cp := cert.NewCAPool() + + templ := firewall.Packet{ + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), + Protocol: firewall.ProtoICMP, + Fragment: false, + } + + t.Run("ICMP allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, 
cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + + t.Run("nonzero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + }) + + t.Run("Any proto, some ports allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 80 + p.RemotePort = 80 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), 
ErrNoMatchingRule) + }) + }) + t.Run("Any proto, any port", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, allowed", func(t *testing.T) { + resetConntrack(fw) + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + + t.Run("nonzero ports, allowed", func(t *testing.T) { + resetConntrack(fw) + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + //different ID is blocked + p.RemotePort++ + require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + }) + +} + func TestFirewall_DropIPSpoofing(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} diff --git a/outside.go b/outside.go index 172c3e83..b2cbf123 100644 --- a/outside.go +++ b/outside.go @@ -327,13 +327,29 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { proto := layers.IPProtocol(data[protoAt]) switch proto { - case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: + case layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) fp.RemotePort = 0 fp.LocalPort = 0 fp.Fragment = false return nil + case layers.IPProtocolICMPv6: + if dataLen < offset+6 { + return ErrIPv6PacketTooShort + } + fp.Protocol = uint8(proto) + fp.LocalPort = 0 //incoming vs outgoing doesn't matter for icmpv6 + icmptype := data[offset+1] 
+ switch icmptype { + case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply: + fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier + default: + fp.RemotePort = 0 + } + fp.Fragment = false + return nil + case layers.IPProtocolTCP, layers.IPProtocolUDP: if dataLen < offset+4 { return ErrIPv6PacketTooShort @@ -423,34 +439,38 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { // Accounting for a variable header length, do we have enough data for our src/dst tuples? minLen := ihl - if !fp.Fragment && fp.Protocol != firewall.ProtoICMP { - minLen += minFwPacketLen + if !fp.Fragment { + if fp.Protocol == firewall.ProtoICMP { + minLen += minFwPacketLen + 2 + } else { + minLen += minFwPacketLen + } } + if len(data) < minLen { return ErrIPv4InvalidHeaderLength } - // Firewall packets are locally oriented - if incoming { + if incoming { // Firewall packets are locally oriented fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { - fp.RemotePort = 0 - fp.LocalPort = 0 - } else { - fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) - } } else { fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { - fp.RemotePort = 0 - fp.LocalPort = 0 - } else { - fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) - } + } + + if fp.Fragment { + fp.RemotePort = 0 + fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = 0 //code would be uint16(data[ihl+1]) + } else if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src 
port + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } return nil diff --git a/outside_test.go b/outside_test.go index 2a750d40..042ccbb3 100644 --- a/outside_test.go +++ b/outside_test.go @@ -155,6 +155,7 @@ func Test_newPacket_v6(t *testing.T) { // next layer, missing length byte err = newPacket(buffer.Bytes()[:49], true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + err = nil // A good ICMP packet ip = layers.IPv6{ @@ -165,20 +166,26 @@ func Test_newPacket_v6(t *testing.T) { DstIP: net.IPv6linklocalallnodes, } - icmp := layers.ICMPv6{} - - buffer.Clear() - err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp) - if err != nil { - panic(err) + icmp := layers.ICMPv6{ + TypeCode: layers.ICMPv6TypeEchoRequest, + Checksum: 0x1234, } - err = newPacket(buffer.Bytes(), true, p) - require.NoError(t, err) + buffer.Clear() + require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp)) + require.Error(t, newPacket(buffer.Bytes(), true, p)) + + buffer.Clear() + echo := layers.ICMPv6Echo{ + Identifier: 0xabcd, + SeqNumber: 1234, + } + require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp, &echo)) + require.NoError(t, newPacket(buffer.Bytes(), true, p)) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) - assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0xabcd), p.RemotePort) assert.Equal(t, uint16(0), p.LocalPort) assert.False(t, p.Fragment) From 7760ccefbaac6d47512e98adc543ca5cfa103d8c Mon Sep 17 00:00:00 2001 From: "Jay R. 
Wren" Date: Fri, 6 Mar 2026 14:03:32 -0500 Subject: [PATCH 09/44] fix logging copy pasta (#1621) --- firewall.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firewall.go b/firewall.go index 2d67acbb..93b16891 100644 --- a/firewall.go +++ b/firewall.go @@ -230,7 +230,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.OutSendReject = false default: - l.WithField("action", inboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") + l.WithField("action", outboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") fw.OutSendReject = false } From 1aa1a0476f6b69a7336684dd7373aeb3b2cb1b18 Mon Sep 17 00:00:00 2001 From: "Jay R. Wren" Date: Mon, 16 Mar 2026 17:07:40 -0400 Subject: [PATCH 10/44] #ECCN:Open Source in CODEOWNERS (#1632) Salesforce is requesting this in all opensource repositories --- CODEOWNERS | 1 + 1 file changed, 1 insertion(+) create mode 100644 CODEOWNERS diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..00cd7bd1 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +#ECCN:Open Source From 9f1aef53fae98a77d2f8372b2588c7e9801f8a3c Mon Sep 17 00:00:00 2001 From: John Maguire Date: Mon, 23 Mar 2026 11:15:40 -0400 Subject: [PATCH 11/44] Fix dissector logic (#1626) * Fix typo in Wireshark dissector * Fix wireshark dissector prefs_changed logic The previous logic had several issues: - Changing only the port number (without toggling all_ports) would not re-register the dissector on the new port. - Turning all_ports off would remove all registrations but only re-add the specific port inside a branch that also required all_ports to have changed, and never updated default_settings.port. Simplify to: remove all registrations, then register based on current prefs, then update the cached state. 
--- dist/wireshark/nebula.lua | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/dist/wireshark/nebula.lua b/dist/wireshark/nebula.lua index ddc808f9..d17dc7a0 100644 --- a/dist/wireshark/nebula.lua +++ b/dist/wireshark/nebula.lua @@ -84,30 +84,24 @@ end function nebula.prefs_changed() if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then - -- Nothing changed, bail return end - -- Remove our old dissector + -- Remove all existing registrations DissectorTable.get("udp.port"):remove_all(nebula) - if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then - default_settings.all_port = nebula.prefs.all_ports - + if nebula.prefs.all_ports then + -- Register on every port for hole punch capture for i=0, 65535 do DissectorTable.get("udp.port"):add(i, nebula) end - - -- no need to establish again on specific ports - return + else + -- Register on the configured port only + DissectorTable.get("udp.port"):add(nebula.prefs.port, nebula) end - - if default_settings.all_ports ~= nebula.prefs.all_ports then - -- Add our new port dissector - default_settings.port = nebula.prefs.port - DissectorTable.get("udp.port"):add(default_settings.port, nebula) - end + default_settings.all_ports = nebula.prefs.all_ports + default_settings.port = nebula.prefs.port end DissectorTable.get("udp.port"):add(default_settings.port, nebula) From 91d1f4675ad05d43a4c69e05a9b287edccbddd26 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Wed, 25 Mar 2026 11:59:37 -0500 Subject: [PATCH 12/44] properly handle closetunnel packets (#1638) --- e2e/tunnels_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++ outside.go | 7 +++ 2 files changed, 112 insertions(+) diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index e89cf869..e8e41945 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -12,6 +12,8 @@ import ( "github.com/slackhq/nebula/cert" 
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v3" ) @@ -365,3 +367,106 @@ func TestCrossStackRelaysWork(t *testing.T) { //theirControl.Stop() //relayControl.Stop() } + +func TestCloseTunnelAuthenticated(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + + r.Log("Assert the tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + r.Log("Close the tunnel") + myControl.CloseTunnel(theirVpnIpNet[0].Addr(), false) + r.FlushAll() + + waitStart := time.Now() + for { + myIndexes := len(myControl.GetHostmap().Indexes) + theirIndexes := len(theirControl.GetHostmap().Indexes) + if myIndexes == 0 && theirIndexes == 0 { + break + } + + since := time.Since(waitStart) + r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since) + if since > time.Second*6 { + t.Fatal("Tunnel should have been declared inactive after 2 seconds and before 6 seconds") + } + + time.Sleep(1 * time.Second) + //r.FlushAll() + } + + r.Logf("Happy path success, tunnels were dropped within %v", 
time.Since(waitStart)) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + r.Log("Assert another tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + hi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if hi == nil { + t.Fatal("There is no hostinfo for this tunnel") + } + myHi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + if myHi == nil { + t.Fatal("There is no hostinfo for my tunnel") + } + r.Log("It does") + + buf := make([]byte, 1024) + hdr := header.H{ + Version: 1, + Type: header.CloseTunnel, + Subtype: 0, + Reserved: 0, + RemoteIndex: hi.RemoteIndex, + MessageCounter: 5, + } + out, err := hdr.Encode(buf) + if err != nil { + t.Fatal(err) + } + + pkt := &udp.Packet{ + To: hi.CurrentRemote, + From: myHi.CurrentRemote, + Data: out, + } + r.InjectUDPPacket(myControl, theirControl, pkt) + r.Log("Injected bogus close tunnel. 
Let's see!") + waitStart = time.Now() + for { + myIndexes := len(myControl.GetHostmap().Indexes) + theirIndexes := len(theirControl.GetHostmap().Indexes) + if myIndexes == 0 { + t.Fatal("myIndexes should not be 0") + } + if theirIndexes == 0 { + t.Fatal("theirIndexes should not be 0, they should have rejected this bogus packet") + } + + since := time.Since(waitStart) + r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since) + if since > time.Second*4 { + t.Log("The tunnel would have been gone by now") + break + } + + time.Sleep(1 * time.Second) + r.FlushAll() + } + + myControl.Stop() + theirControl.Stop() +} diff --git a/outside.go b/outside.go index b2cbf123..eba9d887 100644 --- a/outside.go +++ b/outside.go @@ -190,6 +190,13 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if !f.handleEncrypted(ci, via, h) { return } + _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("from", via). + WithField("packet", packet). + Error("Failed to decrypt CloseTunnel packet") + return + } hostinfo.logger(f.l).WithField("from", via). 
Info("Close tunnel received, tearing down.") From 951d368faf95138c53e492dd788e7dfc1b22e9b7 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Mon, 30 Mar 2026 16:20:21 -0400 Subject: [PATCH 13/44] Add a small link to DN Managed Nebula (#1641) * Add a small link to DN Managed Nebula Also link the mobile source code --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fab9cff1..7cbcb412 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for docker pull nebulaoss/nebula ``` -#### Mobile +#### Mobile ([source code](https://github.com/DefinedNet/mobile_nebula)) - [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200) - [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1) @@ -76,6 +76,8 @@ Nebula was created to provide a mechanism for groups of hosts to communicate sec ## Getting started (quickly) +**Don't want to manage your own PKI and lighthouses?** [Managed Nebula](https://www.defined.net/) from Defined Networking handles all of this for you. + To set up a Nebula network, you'll need: #### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use. From f8587956babc28da162afaba9861c282dc90a90e Mon Sep 17 00:00:00 2001 From: "Jay R. Wren" Date: Fri, 3 Apr 2026 09:37:18 -0400 Subject: [PATCH 14/44] add sshd.sandbox_dir config option (#1622) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add sshd.sandbox_dir config option Sanitize SSH profile paths (ssh.go:514,683,719) — restrict os.Create(a[0]) to a safe directory. 
Add a config option in the config file to specify the sandbox directory. For backwards compatibility, if the config is not specified, keep the current behavior. * update default and example * use os.TempDir() for sshd.sandbox_dir default * split sandbox path validation into separate conditionals Separate the combined && check in sshSanitizeFilePath into two distinct conditionals with specific error messages: one for paths resolving to the sandbox directory itself, and one for paths outside the sandbox. Co-Authored-By: Claude * fix: trim leading zeros from p256 signature swap result bigmod.Nat.Bytes() returns fixed-size 32-byte slices, but ASN.1 INTEGER parsing strips leading zeros. This caused a flaky test failure (~1/256 chance) when the S value's high byte was zero. Co-Authored-By: Claude --------- Co-authored-by: Claude --- cert/p256/p256.go | 7 ++++- examples/config.yml | 6 ++++ ssh.go | 71 +++++++++++++++++++++++++++++++++++++++------ 3 files changed, 74 insertions(+), 10 deletions(-) diff --git a/cert/p256/p256.go b/cert/p256/p256.go index be0a2381..dc609a35 100644 --- a/cert/p256/p256.go +++ b/cert/p256/p256.go @@ -44,7 +44,12 @@ func swap(r, s []byte) ([]byte, []byte, error) { } sNormalized := nMod.Nat().Sub(bigS, nMod) - return r, sNormalized.Bytes(nMod), nil + result := sNormalized.Bytes(nMod) + for len(result) > 1 && result[0] == 0 { + result = result[1:] + } + + return r, result, nil } func Normalize(sig []byte) ([]byte, error) { diff --git a/examples/config.yml b/examples/config.yml index 1f9dc2a4..5bb87d8e 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -204,6 +204,12 @@ punchy: # Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access. #trusted_cas: #- "ssh public key string" + # sandbox_dir restricts file paths for profiling commands (start-cpu-profile, save-heap-profile, + # save-mutex-profile) to the specified directory. 
Relative paths will be resolved within this directory, + # and absolute paths outside of it will be rejected. Default is $TMP/nebula-debug. + # The directory is NOT automatically created. + # Overriding this to "" is the same as "/" and will allow overwriting any path on the host. + #sandbox_dir: /var/tmp/nebula-debug # EXPERIMENTAL: relay support for networks that can't establish direct connections. relay: diff --git a/ssh.go b/ssh.go index 0a9adb51..b2912d55 100644 --- a/ssh.go +++ b/ssh.go @@ -10,6 +10,7 @@ import ( "net" "net/netip" "os" + "path/filepath" "reflect" "runtime" "runtime/pprof" @@ -188,6 +189,12 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro } func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { + // sandboxDir defaults to a dir in temp. The intention is that end user will + // create this dir as needed. Overriding this config value to "" allows + // writing to anywhere in the system. + defaultDir := filepath.Join(os.TempDir(), "nebula-debug") + sandboxDir := c.GetString("sshd.sandbox_dir", defaultDir) + ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", @@ -246,7 +253,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "start-cpu-profile", ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`", - Callback: sshStartCpuProfile, + Callback: func(fs any, a []string, w sshd.StringWriter) error { + return sshStartCpuProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -261,7 +270,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "save-heap-profile", ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`", - Callback: sshGetHeapProfile, + Callback: func(fs any, a 
[]string, w sshd.StringWriter) error { + return sshGetHeapProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -273,7 +284,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "save-mutex-profile", ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`", - Callback: sshGetMutexProfile, + Callback: func(fs any, a []string, w sshd.StringWriter) error { + return sshGetMutexProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -506,13 +519,43 @@ func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) er return nil } -func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error { +// sshSanitizeFilePath validates that the given file path is within the sandbox directory. +// If sandboxDir is empty, the path is returned as-is for backwards compatibility. +func sshSanitizeFilePath(sandboxDir, filePath string) (string, error) { + if sandboxDir == "" { + return filePath, nil + } + + // Clean and resolve the path relative to the sandbox directory + if !filepath.IsAbs(filePath) { + filePath = filepath.Join(sandboxDir, filePath) + } + cleaned := filepath.Clean(filePath) + + // Ensure the resolved path is within the sandbox directory + cleanedSandbox := filepath.Clean(sandboxDir) + if cleaned == cleanedSandbox { + return "", fmt.Errorf("path %q resolves to the sandbox directory itself %q", filePath, sandboxDir) + } + if !strings.HasPrefix(cleaned, cleanedSandbox+string(filepath.Separator)) { + return "", fmt.Errorf("path %q is outside the sandbox directory %q", filePath, sandboxDir) + } + + return cleaned, nil +} + +func sshStartCpuProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { err := w.WriteLine("No path to write profile provided") return err } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + 
return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) return err @@ -676,12 +719,17 @@ func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) e return w.WriteLine("Changed") } -func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error { +func sshGetHeapProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) return err @@ -712,12 +760,17 @@ func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate)) } -func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { +func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) } From 6727113b2b69cfc5b8c91898d671ba5e7dc74ba6 Mon Sep 17 00:00:00 2001 From: "Jay R. Wren" Date: Mon, 6 Apr 2026 12:24:28 -0400 Subject: [PATCH 15/44] gh workflow release: protect from ref_name attack (#1650) It is not likely, but better to be safe. 
--- .github/workflows/release.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9ce1d5e3..a5e8d397 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -209,10 +209,11 @@ jobs: id: create_release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REF_NAME: ${{ github.ref_name }} run: | cd artifacts gh release create \ --verify-tag \ - --title "Release ${{ github.ref_name }}" \ - "${{ github.ref_name }}" \ + --title "Release ${GITHUB_REF_NAME}" \ + "${GITHUB_REF_NAME}" \ SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz From 0ad5c771e91a1b1f341564cb8e060451b947ee7e Mon Sep 17 00:00:00 2001 From: John Maguire Date: Mon, 13 Apr 2026 13:19:55 -0400 Subject: [PATCH 16/44] Refactor CA pool handling to use streaming (#1644) Co-authored-by: maggie44 <64841595+maggie44@users.noreply.github.com> Co-authored-by: JackDoan --- cert/ca_pool.go | 45 +++++++++--- cert/ca_pool_test.go | 57 ++++++++++++++++ cert/pem.go | 82 ++++++++++++++++++---- cert/pem_test.go | 76 +++++++++++++++++++++ cmd/nebula-cert/verify.go | 17 ++--- cmd/nebula-cert/verify_test.go | 2 +- pki.go | 15 ++-- pki_hup_benchmark_test.go | 121 +++++++++++++++++++++++++++++++++ 8 files changed, 373 insertions(+), 42 deletions(-) create mode 100644 pki_hup_benchmark_test.go diff --git a/cert/ca_pool.go b/cert/ca_pool.go index e9903e1f..792f8e66 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -1,11 +1,14 @@ package cert import ( + "bufio" + "bytes" + "encoding/pem" "errors" "fmt" + "io" "net/netip" "slices" - "strings" "time" ) @@ -29,22 +32,46 @@ func NewCAPool() *CAPool { // If the pool contains any expired certificates, an ErrExpired will be // returned along with the pool. The caller must handle any such errors. 
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { + return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs)) +} + +// NewCAPoolFromPEMReader will create a new CA pool from the provided reader. +// The reader must contain a PEM-encoded set of nebula certificates. +func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) { pool := NewCAPool() - var err error + var expired bool - for { - caPEMs, err = pool.AddCAFromPEM(caPEMs) - if errors.Is(err, ErrExpired) { - expired = true - err = nil + + scanner := bufio.NewScanner(r) + scanner.Split(SplitPEM) + + for scanner.Scan() { + pemBytes := scanner.Bytes() + + block, rest := pem.Decode(pemBytes) + if len(bytes.TrimSpace(rest)) > 0 { + return nil, ErrInvalidPEMBlock } + if block == nil { + return nil, ErrInvalidPEMBlock + } + + c, err := unmarshalCertificateBlock(block) if err != nil { return nil, err } - if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { - break + + err = pool.AddCA(c) + if errors.Is(err, ErrExpired) { + expired = true + continue + } else if err != nil { + return nil, err } } + if err := scanner.Err(); err != nil { + return nil, ErrInvalidPEMBlock + } if expired { return pool, ErrExpired diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index e872c7d4..ab173228 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -1,7 +1,10 @@ package cert import ( + "bytes" + "io" "net/netip" + "strings" "testing" "time" @@ -112,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe assert.Len(t, ppppp.CAs, 1) } +// oneByteReader wraps a reader to return at most 1 byte per Read call, +// exercising the streaming accumulation logic in NewCAPoolFromPEMReader. 
+type oneByteReader struct { + r io.Reader +} + +func (o *oneByteReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + return o.r.Read(p[:1]) +} + +func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) { + pool, err := NewCAPoolFromPEMReader(bytes.NewReader(nil)) + require.NoError(t, err) + assert.Empty(t, pool.CAs) + + pool, err = NewCAPoolFromPEMReader(strings.NewReader(" \n\t\n ")) + require.NoError(t, err) + assert.Empty(t, pool.CAs) +} + +func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) { + ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + + bundle := append(pem1, pem2...) + pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)}) + require.NoError(t, err) + assert.Len(t, pool.CAs, 2) + + fp1, err := ca1.Fingerprint() + require.NoError(t, err) + fp2, err := ca2.Fingerprint() + require.NoError(t, err) + + assert.Contains(t, pool.CAs, fp1) + assert.Contains(t, pool.CAs, fp2) +} + +func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) { + _, err := NewCAPoolFromPEMReader(strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata")) + assert.ErrorIs(t, err, ErrInvalidPEMBlock) +} + +func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) { + _, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + + bundle := append(pem1, []byte("some trailing garbage")...) 
+ _, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle)) + assert.ErrorIs(t, err, ErrInvalidPEMBlock) +} + func TestCertificateV1_Verify(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) diff --git a/cert/pem.go b/cert/pem.go index 8942c23a..84221b22 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -1,12 +1,66 @@ package cert import ( + "bytes" "encoding/pem" + "errors" "fmt" "golang.org/x/crypto/ed25519" ) +var ErrTruncatedPEMBlock = errors.New("truncated PEM block") + +// SplitPEM is a split function for bufio.Scanner that returns each PEM block. +func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) { + // Look for the start of a PEM block + start := bytes.Index(data, []byte("-----BEGIN ")) + if start == -1 { + if atEOF && len(bytes.TrimSpace(data)) > 0 { + // Non-whitespace content with no PEM block + return 0, nil, ErrTruncatedPEMBlock + } + if atEOF { + return len(data), nil, nil + } + // Request more data + return 0, nil, nil + } + + // Look for the end marker + endMarkerStart := bytes.Index(data[start:], []byte("-----END ")) + if endMarkerStart == -1 { + if atEOF { + // Incomplete PEM block at EOF + return 0, nil, ErrTruncatedPEMBlock + } + // Need more data to find the end + return 0, nil, nil + } + + // Find the actual end of the END line (after the newline) + endMarkerStart += start + endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n') + var end int + if endLineEnd == -1 { + if atEOF { + // END marker without newline at EOF - take it anyway + end = len(data) + } else { + // Need more data + return 0, nil, nil + } + } else { + end = endMarkerStart + endLineEnd + 1 + } + + // Extract the PEM block + pemBlock := data[start:end] + + // Return the valid PEM block + return end, pemBlock, nil +} + const ( //cert 
banners CertificateBanner = "NEBULA CERTIFICATE" CertificateV2Banner = "NEBULA CERTIFICATE V2" @@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { return nil, r, ErrInvalidPEMBlock } - var c Certificate - var err error - - switch p.Type { - // Implementations must validate the resulting certificate contains valid information - case CertificateBanner: - c, err = unmarshalCertificateV1(p.Bytes, nil) - case CertificateV2Banner: - c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) - default: - return nil, r, ErrInvalidPEMCertificateBanner - } - + c, err := unmarshalCertificateBlock(p) if err != nil { return nil, r, err } @@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { } +// unmarshalCertificateBlock decodes a single PEM block into a certificate. +// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise. +func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) { + switch block.Type { + // Implementations must validate the resulting certificate contains valid information + case CertificateBanner: + return unmarshalCertificateV1(block.Bytes, nil) + case CertificateV2Banner: + return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519) + default: + return nil, ErrInvalidPEMCertificateBanner + } +} + func marshalCertPublicKeyToPEM(c Certificate) []byte { if c.IsCA() { return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey()) diff --git a/cert/pem_test.go b/cert/pem_test.go index 310c57a3..ff623541 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -1,12 +1,88 @@ package cert import ( + "bufio" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func scanAll(t *testing.T, input string) ([]string, error) { + t.Helper() + scanner := bufio.NewScanner(strings.NewReader(input)) + scanner.Split(SplitPEM) + var blocks []string + for scanner.Scan() { + blocks = 
append(blocks, scanner.Text()) + } + return blocks, scanner.Err() +} + +func TestSplitPEM_Single(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 1) + require.Equal(t, input, blocks[0]) +} + +func TestSplitPEM_Multiple(t *testing.T) { + block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n" + block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n" + blocks, err := scanAll(t, block1+block2) + require.NoError(t, err) + require.Len(t, blocks, 2) + require.Equal(t, block1, blocks[0]) + require.Equal(t, block2, blocks[1]) +} + +func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) { + input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 2) +} + +func TestSplitPEM_Empty(t *testing.T) { + blocks, err := scanAll(t, "") + require.NoError(t, err) + require.Empty(t, blocks) +} + +func TestSplitPEM_WhitespaceOnly(t *testing.T) { + blocks, err := scanAll(t, " \n\t\n ") + require.NoError(t, err) + require.Empty(t, blocks) +} + +func TestSplitPEM_TrailingGarbage(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage" + blocks, err := scanAll(t, input) + require.ErrorIs(t, err, ErrTruncatedPEMBlock) + require.Len(t, blocks, 1) +} + +func TestSplitPEM_TruncatedBlock(t *testing.T) { + input := "-----BEGIN TEST-----\npartial data with no end" + _, err := scanAll(t, input) + require.ErrorIs(t, err, ErrTruncatedPEMBlock) +} + +func TestSplitPEM_NoEndNewline(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 1) + require.Equal(t, input, blocks[0]) +} + +func TestSplitPEM_GarbageOnly(t *testing.T) { + _, err := scanAll(t, "this is not PEM 
data") + require.ErrorIs(t, err, ErrTruncatedPEMBlock) +} + func TestUnmarshalCertificateFromPEM(t *testing.T) { goodCert := []byte(` # A good cert diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index bea4d1d9..36258dd8 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "os" - "strings" "time" "github.com/slackhq/nebula/cert" @@ -40,21 +39,15 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCACert, err := os.ReadFile(*vf.caPath) + caFile, err := os.Open(*vf.caPath) if err != nil { return fmt.Errorf("error while reading ca: %w", err) } + defer caFile.Close() - caPool := cert.NewCAPool() - for { - rawCACert, err = caPool.AddCAFromPEM(rawCACert) - if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %w", err) - } - - if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { - break - } + caPool, err := cert.NewCAPoolFromPEMReader(caFile) + if err != nil && !errors.Is(err, cert.ErrExpired) { + return fmt.Errorf("error while adding ca cert to pool: %w", err) } rawCert, err := os.ReadFile(*vf.certPath) diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f555e5f5..1aa5e8e6 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -64,7 +64,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) - require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") + require.ErrorIs(t, err, cert.ErrInvalidPEMBlock) // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) diff --git a/pki.go b/pki.go index 19869d58..0639fd3d 100644 --- a/pki.go +++ b/pki.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "net/netip" "os" @@ -487,25 +488,25 @@ func 
loadCertificate(b []byte) (cert.Certificate, []byte, error) { } func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { - var rawCA []byte - var err error - caPathOrPEM := c.GetString("pki.ca", "") if caPathOrPEM == "" { return nil, errors.New("no pki.ca path or PEM data provided") } - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) + var caReader io.ReadCloser + var err error + if strings.Contains(caPathOrPEM, "-----BEGIN") { + caReader = io.NopCloser(strings.NewReader(caPathOrPEM)) } else { - rawCA, err = os.ReadFile(caPathOrPEM) + caReader, err = os.Open(caPathOrPEM) if err != nil { return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) } } + defer caReader.Close() - caPool, err := cert.NewCAPoolFromPEM(rawCA) + caPool, err := cert.NewCAPoolFromPEMReader(caReader) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { diff --git a/pki_hup_benchmark_test.go b/pki_hup_benchmark_test.go new file mode 100644 index 00000000..39f648ff --- /dev/null +++ b/pki_hup_benchmark_test.go @@ -0,0 +1,121 @@ +package nebula + +import ( + "bytes" + "fmt" + "net/netip" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + cert_test "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/require" +) + +func BenchmarkReloadConfigWithCAs(b *testing.B) { + prevProcs := runtime.GOMAXPROCS(1) + b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) }) + + for _, size := range []int{100, 250, 500, 1000, 5000} { + b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) { + l := test.NewLogger() + dir := b.TempDir() + + ca, caKey, caBundle := buildCABundle(b, size) + caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle) + + configBody := fmt.Sprintf(`pki: + ca: %s + cert: %s + key: %s +`, caPath, certPath, keyPath) + + configPath := 
filepath.Join(dir, "config.yml") + require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600)) + + c := config.NewC(l) + require.NoError(b, c.Load(dir)) + + _, err := NewPKIFromConfig(l, c) + require.NoError(b, err) + + b.ReportAllocs() + b.ResetTimer() + + for b.Loop() { + c.ReloadConfig() + } + }) + } +} + +func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) { + b.Helper() + require.GreaterOrEqual(b, count, 1) + + before := time.Now().Add(-24 * time.Hour) + after := time.Now().Add(24 * time.Hour) + + ca, _, caKey, pem := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + before, + after, + nil, + nil, + nil, + ) + + buf := bytes.NewBuffer(pem) + buf.Write([]byte("\n# a comment!\n")) + + for i := 1; i < count; i++ { + _, _, _, extraPEM := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + time.Now(), + time.Now().Add(time.Hour), + nil, + nil, + nil, + ) + buf.Write([]byte("\n# a comment!\n")) + buf.Write(extraPEM) + } + + return ca, caKey, buf.Bytes() +} + +func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) { + b.Helper() + + networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")} + + _, _, keyPEM, certPEM := cert_test.NewTestCert( + cert.Version2, + cert.Curve_CURVE25519, + ca, + caKey, + "reload-benchmark", + time.Now(), + time.Now().Add(time.Hour), + networks, + nil, + nil, + ) + + caPath := filepath.Join(dir, "ca.pem") + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + require.NoError(b, os.WriteFile(caPath, caBundle, 0o600)) + require.NoError(b, os.WriteFile(certPath, certPEM, 0o600)) + require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600)) + + return caPath, certPath, keyPath +} From 3fae693c428daefca439b685695323a34bed66bc Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 14 Apr 2026 13:32:01 -0500 Subject: [PATCH 17/44] Additional e2e tests to assert current 
handshake behavior (#1653) --- e2e/handshake_manager_test.go | 565 ++++++++++++++++++++++++++++++++++ 1 file changed, 565 insertions(+) create mode 100644 e2e/handshake_manager_test.go diff --git a/e2e/handshake_manager_test.go b/e2e/handshake_manager_test.go new file mode 100644 index 00000000..3fe784c1 --- /dev/null +++ b/e2e/handshake_manager_test.go @@ -0,0 +1,565 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" + "github.com/stretchr/testify/assert" +) + +// makeHandshakePacket creates a handshake packet with the given parameters. +func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, remoteIndex uint32, counter uint64) *udp.Packet { + data := make([]byte, 200) + header.Encode(data, header.Version, header.Handshake, subtype, remoteIndex, counter) + for i := header.Len; i < len(data); i++ { + data[i] = byte(i) + } + return &udp.Packet{To: to, From: from, Data: data} +} + +func TestHandshakeRetransmitDuplicate(t *testing.T) { + // Verify the responder correctly handles receiving the same msg1 multiple times + // (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen + // and the cached response is resent. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake from me to them") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + + t.Log("Grab my msg1") + msg1 := myControl.GetFromUDP(true) + + t.Log("Inject msg1 into them, first time") + theirControl.InjectUDPPacket(msg1) + _ = theirControl.GetFromUDP(true) + + t.Log("Inject the SAME msg1 again, tests ErrAlreadySeen path") + theirControl.InjectUDPPacket(msg1) + resp2 := theirControl.GetFromUDP(true) + assert.NotNil(t, resp2, "should get cached response on duplicate msg1") + + t.Log("Complete handshake with cached response") + myControl.InjectUDPPacket(resp2) + myControl.WaitForType(1, 0, theirControl) + + t.Log("Drain cached packet and verify tunnel works") + cachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify only one tunnel exists on each side") + assert.Len(t, myControl.ListHostmapHosts(false), 1) + assert.Len(t, theirControl.ListHostmapHosts(false), 1) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeTruncatedPacketRecovery(t *testing.T) { + // Verify that a truncated handshake packet is ignored and the real + // packet can still 
complete the handshake. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + + t.Log("Get msg1 and deliver to responder") + msg1 := myControl.GetFromUDP(true) + theirControl.InjectUDPPacket(msg1) + + t.Log("Get the real response") + realResp := theirControl.GetFromUDP(true) + + t.Log("Truncate the response and inject, should be ignored") + truncResp := realResp.Copy() + truncResp.Data = truncResp.Data[:header.Len] + myControl.InjectUDPPacket(truncResp) + + t.Log("Verify pending handshake survived the truncated packet") + assert.NotEmpty(t, myControl.ListHostmapHosts(true), "pending handshake should still exist") + + t.Log("Inject real response, should complete handshake") + myControl.InjectUDPPacket(realResp) + myControl.WaitForType(1, 0, theirControl) + + t.Log("Drain and verify tunnel") + cachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { + // A msg2 arriving with no matching pending index should be silently dropped + // with no response sent and no state changes. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Complete a normal handshake") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + r.RouteForAllUntilTxTun(theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Record hostmap state") + myIndexes := len(myControl.ListHostmapIndexes(false)) + + t.Log("Inject a fake msg2 with unknown RemoteIndex") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0xDEADBEEF, 2)) + + t.Log("Verify no new indexes created") + assert.Equal(t, myIndexes, len(myControl.ListHostmapIndexes(false))) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false), "should not send a response to orphaned msg2") + + t.Log("Verify existing tunnel still works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeUnknownMessageCounter(t *testing.T) { + // A handshake packet with an unexpected message counter should be silently + // dropped with no side effects and no UDP response. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + myControl.Start() + theirControl.Start() + + t.Log("Inject handshake with MessageCounter=3") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 3)) + + t.Log("Inject handshake with MessageCounter=99") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 99)) + + t.Log("Verify no tunnels or pending handshakes") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeUnknownSubtype(t *testing.T) { + // A handshake packet with an unknown subtype should be silently dropped. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.Start() + theirControl.Start() + + t.Log("Inject handshake with unknown subtype 99") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.MessageSubType(99), 0, 1)) + + t.Log("Verify no tunnels or pending handshakes") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeLateResponse(t *testing.T) { + // After a handshake times out, a late response should be silently ignored + // with no new tunnels created. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{ + "handshakes": m{ + "try_interval": "200ms", + "retries": 2, + }, + }) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + myControl.Start() + theirControl.Start() + + t.Log("Trigger handshake from me") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + + t.Log("Grab msg1 but don't deliver") + msg1 := myControl.GetFromUDP(true) + + t.Log("Wait for handshake to time out") + for i := 0; i < 5; i++ { + time.Sleep(300 * time.Millisecond) + myControl.GetFromUDP(false) + } + + t.Log("Confirm no pending handshakes remain") + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Deliver old msg1 to them, they create a tunnel") + theirControl.InjectUDPPacket(msg1) + resp := theirControl.GetFromUDP(true) + assert.NotNil(t, resp) + + t.Log("Inject late response into me, should be ignored") + myControl.InjectUDPPacket(resp) + + t.Log("No tunnel should exist on my side") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeSelfConnectionRejected(t *testing.T) { + // Verify that a node rejects a handshake containing its own VPN IP in the + // peer cert. We do this by sending the initiator's own msg1 back to itself. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + + // Need a lighthouse entry to trigger a handshake + myControl.InjectLightHouseAddr(netip.MustParseAddr("10.128.0.2"), netip.MustParseAddrPort("10.0.0.2:4242")) + + myControl.Start() + + t.Log("Trigger handshake from me") + myControl.InjectTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + msg1 := myControl.GetFromUDP(true) + + t.Log("Drain any handshake retransmits before injecting") + time.Sleep(100 * time.Millisecond) + for myControl.GetFromUDP(false) != nil { + } + + t.Log("Feed my own msg1 back to me as if it came from someone else") + selfMsg := msg1.Copy() + selfMsg.From = netip.MustParseAddrPort("10.0.0.99:4242") + selfMsg.To = myUdpAddr + myControl.InjectUDPPacket(selfMsg) + + t.Log("Verify no response was sent (self-connection rejected)") + time.Sleep(100 * time.Millisecond) + // Drain any further retransmits from the original handshake, then check + // that none of them are a handshake response (MessageCounter=2) + h := &header.H{} + for { + p := myControl.GetFromUDP(false) + if p == nil { + break + } + _ = h.Parse(p.Data) + assert.NotEqual(t, uint64(2), h.MessageCounter, + "should not send a stage 2 response to self-connection") + } + + t.Log("Verify no tunnel to myself was created") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)) + + myControl.Stop() +} + +func TestHandshakeMessageCounter0Dropped(t *testing.T) { + // MessageCounter=0 is not a valid handshake message and should be dropped. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + _, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.Start() + + t.Log("Inject handshake with MessageCounter=0") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 0)) + + time.Sleep(100 * time.Millisecond) + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() +} + +func TestHandshakeRemoteAllowList(t *testing.T) { + // Verify that a handshake from a blocked underlay IP is dropped with no + // response and no state changes. Then verify the same packet from an + // allowed IP succeeds. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{ + "lighthouse": m{ + "remote_allow_list": m{ + "10.0.0.0/8": true, + "0.0.0.0/0": false, + }, + }, + }) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake from them") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")) + msg1 := theirControl.GetFromUDP(true) + + t.Log("Rewrite the source to a blocked IP and inject") + blockedMsg := msg1.Copy() + 
blockedMsg.From = netip.MustParseAddrPort("192.168.1.1:4242") + myControl.InjectUDPPacket(blockedMsg) + + t.Log("Verify no tunnel, no pending, no response from blocked source") + time.Sleep(100 * time.Millisecond) + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + assert.Nil(t, myControl.GetFromUDP(false), "should not respond to blocked source") + + t.Log("Now inject the real packet from the allowed source") + myControl.InjectUDPPacket(msg1) + + t.Log("Verify handshake completes from allowed source") + resp := myControl.GetFromUDP(true) + assert.NotNil(t, resp) + theirControl.InjectUDPPacket(resp) + theirControl.WaitForType(1, 0, myControl) + + t.Log("Drain cached packet and verify tunnel works") + cachedPacket := myControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { + // When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel + // remains functional and hostmap index count is stable. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Complete a normal handshake via the router") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + r.RouteForAllUntilTxTun(theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Record hostmap state") + theirIndexes := len(theirControl.ListHostmapIndexes(false)) + hi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, hi) + originalRemote := hi.CurrentRemote + + t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")) + r.RouteForAllUntilTxTun(theirControl) + + t.Log("Verify tunnel still works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify remote is still valid and index count is stable") + hi2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, hi2) + assert.Equal(t, originalRemote, hi2.CurrentRemote) + assert.Equal(t, theirIndexes, len(theirControl.ListHostmapIndexes(false)), + "no extra indexes should be created from ErrAlreadySeen") + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeWrongResponderPacketStore(t *testing.T) { + // Verify that 
when the wrong host responds, the cached packets are + // transferred to the new handshake, the evil tunnel is closed, evil's + // address is blocked, and the correct tunnel is eventually established. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIpNet, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) + + r := router.NewR(t, myControl, theirControl, evilControl) + defer r.RenderFlow() + + myControl.Start() + theirControl.Start() + evilControl.Start() + + t.Log("Send multiple packets to them (cached during handshake)") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")) + + t.Log("Route until evil tunnel is closed") + h := &header.H{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + if err := h.Parse(p.Data); err != nil { + panic(err) + } + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { + return router.RouteAndExit + } + return router.KeepRouting + }) + + t.Log("Verify evil's address is blocked in the new pending handshake") + pendingHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) + if pendingHI != nil { + assert.NotContains(t, pendingHI.RemoteAddrs, evilUdpAddr, + "evil's address should be blocked") + } + + t.Log("Inject correct lighthouse addr for them") + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + t.Log("Route until cached packets arrive at the 
real them") + p := r.RouteForAllUntilTxTun(theirControl) + assert.NotNil(t, p, "a cached packet should be delivered to the correct host") + + t.Log("Verify the correct host has a tunnel") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + + t.Log("Verify no hostinfo artifacts from evil remain") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), true), + "no pending hostinfo for evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), false), + "no main hostinfo for evil") + + myControl.Stop() + theirControl.Stop() + evilControl.Stop() +} + +func TestHandshakeRelayComplete(t *testing.T) { + // Verify that a relay handshake completes correctly and relay state is + // properly maintained on all three nodes. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger handshake via relay") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")) + + p := r.RouteForAllUntilTxTun(theirControl) + 
assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Verify bidirectional tunnel via relay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify relay state on my side shows relay-to-me") + myHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + assert.NotNil(t, myHI) + assert.NotEmpty(t, myHI.CurrentRelaysToMe, "should have relay-to-me for them") + + t.Log("Verify relay state on their side shows relay-to-me") + theirHI := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, theirHI) + assert.NotEmpty(t, theirHI.CurrentRelaysToMe, "should have relay-to-me for me") + + t.Log("Verify relay node shows through-me relays") + relayHI := relayControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, relayHI) + + myControl.Stop() + relayControl.Stop() + theirControl.Stop() +} + +// NOTE: Relay V1 cert + IPv6 rejection is not tested here because +// InjectTunUDPPacket from a V4 node to a V6 address panics in the test +// framework. The check is in handshake_manager.go handleOutbound relay +// logic (lines ~304-313): if the relay host has a V1 cert and either +// address is IPv6, the relay is skipped. + +// NOTE: Relay reestablishment (Disestablished state transition) is covered +// by the existing TestReestablishRelays in handshakes_test.go. 
From b3194236aac4d4a577812fbd3f1ec66c2e5e1a60 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Tue, 14 Apr 2026 18:25:24 -0500 Subject: [PATCH 18/44] udp_linux: wrap socket operations with syscall.RawConn for clean teardown (#1654) remove runtime.LockOSThread() because it makes things worse now remove the "custom" Write() method from tun_linux.go, the stdlib path via os.File performs better We should change our guidance around number of routines, ~2 per thread (that you wish to use for Nebula) seems to be about right now --- interface.go | 5 - overlay/tun_linux.go | 23 ---- udp/udp_linux.go | 322 ++++++++++++++++++++++--------------------- 3 files changed, 162 insertions(+), 188 deletions(-) diff --git a/interface.go b/interface.go index 61b1f228..61f8c9b7 100644 --- a/interface.go +++ b/interface.go @@ -7,7 +7,6 @@ import ( "io" "net/netip" "os" - "runtime" "sync/atomic" "time" @@ -263,8 +262,6 @@ func (f *Interface) run() { } func (f *Interface) listenOut(i int) { - runtime.LockOSThread() - var li udp.Conn if i > 0 { li = f.writers[i] @@ -285,8 +282,6 @@ func (f *Interface) listenOut(i int) { } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - runtime.LockOSThread() - packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7e4aa418..9d779a4b 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -261,29 +261,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t *tun) Write(b []byte) (int, error) { - var nn int - maximum := len(b) - - for { - n, err := unix.Write(t.fd, b[nn:maximum]) - if n > 0 { - nn += n - } - if nn == len(b) { - return nn, err - } - - if err != nil { - return nn, err - } - - if n == 0 { - return nn, io.ErrUnexpectedEOF - } - } -} - func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e7759329..b1490a1c 100644 --- 
a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -4,6 +4,7 @@ package udp import ( + "context" "encoding/binary" "fmt" "net" @@ -18,58 +19,58 @@ import ( ) type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int + udpConn *net.UDPConn + rawConn syscall.RawConn + isV4 bool + l *logrus.Logger + batch int } -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true +func setReusePort(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + //CloseOnExec already set by the runtime + }) + if err != nil { + return err } - return ip, false + return opErr } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - af := unix.AF_INET6 - if ip.Is4() { - af = unix.AF_INET + listen := netip.AddrPortFrom(ip, uint16(port)) + lc := net.ListenConfig{} + if multi { + lc.Control = setReusePort } - syscall.ForkLock.RLock() - fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) - if err == nil { - unix.CloseOnExec(fd) - } - syscall.ForkLock.RUnlock() - + //this context is only used during the bind operation, you can't cancel it to kill the socket + pc, err := lc.ListenPacket(context.Background(), "udp", listen.String()) if err != nil { - unix.Close(fd) return nil, fmt.Errorf("unable to open socket: %s", err) } - - if multi { - if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { - return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) - } + udpConn := pc.(*net.UDPConn) + rawConn, err := udpConn.SyscallConn() + if err != nil { + _ = udpConn.Close() + return nil, err + } + //gotta find out if we got an AF_INET6 socket or not: + out := &StdConn{ + udpConn: udpConn, + rawConn: rawConn, + l: l, + batch: batch, } - var sa unix.Sockaddr - if ip.Is4() { - sa4 := &unix.SockaddrInet4{Port: port} - sa4.Addr = ip.As4() - sa 
= sa4 - } else { - sa6 := &unix.SockaddrInet6{Port: port} - sa6.Addr = ip.As16() - sa = sa6 - } - if err = unix.Bind(fd, sa); err != nil { - return nil, fmt.Errorf("unable to bind to socket: %s", err) + af, err := out.getSockOptInt(unix.SO_DOMAIN) + if err != nil { + _ = out.Close() + return nil, err } + out.isV4 = af == unix.AF_INET - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return out, nil } func (u *StdConn) SupportsMultipleReaders() bool { @@ -80,63 +81,137 @@ func (u *StdConn) Rebind() error { return nil } +func (u *StdConn) getSockOptInt(opt int) (int, error) { + if u.rawConn == nil { + return 0, fmt.Errorf("no UDP connection") + } + var out int + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + out, opErr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, opt) + }) + if err != nil { + return 0, err + } + return out, opErr +} + +func (u *StdConn) setSockOptInt(opt int, n int) error { + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, opt, n) + }) + if err != nil { + return err + } + return opErr +} + func (u *StdConn) SetRecvBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) + return u.setSockOptInt(unix.SO_RCVBUFFORCE, n) } func (u *StdConn) SetSendBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) + return u.setSockOptInt(unix.SO_SNDBUFFORCE, n) } func (u *StdConn) SetSoMark(mark int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) + return u.setSockOptInt(unix.SO_MARK, mark) } func (u *StdConn) GetRecvBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) + return u.getSockOptInt(unix.SO_RCVBUF) } func (u *StdConn) GetSendBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, 
unix.SO_SNDBUF) + return u.getSockOptInt(unix.SO_SNDBUF) } func (u *StdConn) GetSoMark() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) + return u.getSockOptInt(unix.SO_MARK) } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { - sa, err := unix.Getsockname(u.sysFd) - if err != nil { - return netip.AddrPort{}, err - } + a := u.udpConn.LocalAddr() - switch sa := sa.(type) { - case *unix.SockaddrInet4: - return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil - - case *unix.SockaddrInet6: - return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil + switch v := a.(type) { + case *net.UDPAddr: + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } -func (u *StdConn) ListenOut(r EncReader) { +func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { + var errno syscall.Errno + n, _, errno := unix.Syscall6( + unix.SYS_RECVMMSG, + fd, + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + unix.MSG_WAITFORONE, + 0, + 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + // No data available, block for I/O and try again. 
+ return int(n), false, nil + } + if errno != 0 { + return int(n), true, &net.OpError{Op: "recvmmsg", Err: errno} + } + return int(n), true, nil +} + +func (u *StdConn) listenOutSingle(r EncReader) { + var err error + var n int + var from netip.AddrPort + buffer := make([]byte, MTU) + + for { + n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) + r(from, buffer[:n]) + } +} + +func (u *StdConn) listenOutBatch(r EncReader) { var ip netip.Addr + var n int + var operr error msgs, buffers, names := u.PrepareRawMessages(u.batch) - read := u.ReadMulti - if u.batch == 1 { - read = u.ReadSingle + + //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read + //defining it outside the loop so it gets re-used + reader := func(fd uintptr) (done bool) { + n, done, operr = recvmmsg(fd, msgs) + return done } for { - n, err := read(msgs) + err := u.rawConn.Read(reader) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } + if operr != nil { + u.l.WithError(operr).Debug("operr: udp socket is closed, exiting read loop") + return + } for i := 0; i < n; i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic @@ -150,106 +225,20 @@ func (u *StdConn) ListenOut(r EncReader) { } } -func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - 0, - 0, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmsg", Err: err} - } - - msgs[0].Len = uint32(n) - return 1, nil - } -} - -func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&msgs[0])), - 
uintptr(len(msgs)), - unix.MSG_WAITFORONE, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmmsg", Err: err} - } - - return int(n), nil +func (u *StdConn) ListenOut(r EncReader) { + if u.batch == 1 { + //save some ram by not calling PrepareRawMessages for fields we won't use + //we could also make this path more common by calling recvmmsg with msgs[:1], + //but that's still the recvmmsg syscall, which would be a change + u.listenOutSingle(r) + } else { + u.listenOutBatch(r) } } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, ip) - } - return u.writeTo6(b, ip) -} - -func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { - var rsa unix.RawSockaddrInet6 - rsa.Family = unix.AF_INET6 - rsa.Addr = ip.Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet6), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } -} - -func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { - if !ip.Addr().Is4() { - return ErrInvalidIPv6RemoteForSocket - } - - var rsa unix.RawSockaddrInet4 - rsa.Family = unix.AF_INET - rsa.Addr = ip.Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet4), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } + _, err := u.udpConn.WriteToUDPAddrPort(b, ip) + return err } func (u *StdConn) ReloadConfig(c *config.C) { @@ -302,15 +291,28 @@ func (u *StdConn) ReloadConfig(c *config.C) { func (u *StdConn) 
getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS - _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) - if err != 0 { + + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + _, _, syserr := unix.Syscall6(unix.SYS_GETSOCKOPT, fd, uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) + if syserr != 0 { + opErr = syserr + } + }) + if err != nil { return err } - return nil + return opErr } func (u *StdConn) Close() error { - return syscall.Close(u.sysFd) + if u.udpConn != nil { + return u.udpConn.Close() + } + return nil } func NewUDPStatsEmitter(udpConns []Conn) func() { From a5e81efe7bab0c24758e300456992f4875955f3b Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 15 Apr 2026 09:23:33 -0500 Subject: [PATCH 19/44] Try rsync from somewhere else (#1655) --- .github/workflows/smoke/smoke-vagrant.sh | 11 +++++++++++ .github/workflows/smoke/smoke.sh | 1 + .../workflows/smoke/vagrant-openbsd-amd64/Vagrantfile | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh index 1c1e3c50..115bbde0 100755 --- a/.github/workflows/smoke/smoke-vagrant.sh +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -29,6 +29,17 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test docker run --name host2 --rm "$CONTAINER" -config host2.yml -test vagrant up + +# OpenBSD: synced folders are disabled because Vagrant's rsync installer +# uses ftp.openbsd.org which no longer hosts packages for older releases. +# Copy build artifacts in via scp instead. 
+case "$1" in + openbsd-*) + vagrant ssh -c "sudo mkdir -p /nebula" -- -T + tar -cf - -C build . | vagrant ssh -c "sudo tar -xf - -C /nebula && sudo chmod -R a+r /nebula" -- -T + ;; +esac + vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & diff --git a/.github/workflows/smoke/smoke.sh b/.github/workflows/smoke/smoke.sh index 66164921..f8cda450 100755 --- a/.github/workflows/smoke/smoke.sh +++ b/.github/workflows/smoke/smoke.sh @@ -124,6 +124,7 @@ set -x # host2 speaking to host4 on UDP 4000 should allow it to reply, when firewall rules would normally not permit this docker exec host2 sh -c "/usr/bin/echo host2 | ncat -nuv 192.168.100.4 4000" docker exec host2 ncat -e '/usr/bin/echo helloagainfromhost2' -nkluv 0.0.0.0 4000 & +sleep 1 docker exec host4 sh -c "/usr/bin/echo host4 | ncat -nuv 192.168.100.2 4000" docker exec host4 sh -c 'kill 1' diff --git a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile index e4f41049..4bbf3e0e 100644 --- a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile +++ b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile @@ -3,5 +3,5 @@ Vagrant.configure("2") do |config| config.vm.box = "generic/openbsd7" - config.vm.synced_folder "../build", "/nebula", type: "rsync" + config.vm.synced_folder ".", "/vagrant", disabled: true end From 24c9c704a0ceea8ea511656f36ff6f21955d185b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 11:54:47 -0500 Subject: [PATCH 20/44] Bump github.com/miekg/dns from 1.1.70 to 1.1.72 (#1587) Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.70 to 1.1.72. 
- [Commits](https://github.com/miekg/dns/compare/v1.1.70...v1.1.72) --- updated-dependencies: - dependency-name: github.com/miekg/dns dependency-version: 1.1.72 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index f302f928..84616900 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 - github.com/miekg/dns v1.1.70 + github.com/miekg/dns v1.1.72 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.23.2 diff --git a/go.sum b/go.sum index f4b1074c..dba14624 100644 --- a/go.sum +++ b/go.sum @@ -85,8 +85,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA= -github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod 
h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= From f77fe741924ac047a70c669ed7eb8305d3756af1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 12:27:19 -0500 Subject: [PATCH 21/44] Bump github.com/miekg/pkcs11 (#1586) Bumps [github.com/miekg/pkcs11](https://github.com/miekg/pkcs11) from 1.1.2-0.20231115102856-9078ad6b9d4b to 1.1.2. - [Changelog](https://github.com/miekg/pkcs11/blob/master/release.go) - [Commits](https://github.com/miekg/pkcs11/commits/v1.1.2) --- updated-dependencies: - dependency-name: github.com/miekg/pkcs11 dependency-version: 1.1.2 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 84616900..63977295 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 github.com/miekg/dns v1.1.72 - github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b + github.com/miekg/pkcs11 v1.1.2 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.23.2 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 diff --git a/go.sum b/go.sum index dba14624..e3b317f9 100644 --- a/go.sum +++ b/go.sum @@ -87,8 +87,8 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+ github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= -github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= 
-github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/miekg/pkcs11 v1.1.2 h1:/VxmeAX5qU6Q3EwafypogwWbYryHFmF2RpkJmw3m4MQ= +github.com/miekg/pkcs11 v1.1.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= From 36ab1dbb97e508bae056e6cac786cea15f4d98c8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 13:02:29 -0500 Subject: [PATCH 22/44] Bump the golang-x-dependencies group across 1 directory with 5 updates (#1629) Bumps the golang-x-dependencies group with 3 updates in the / directory: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net) and [golang.org/x/sync](https://github.com/golang/sync). 
Updates `golang.org/x/crypto` from 0.47.0 to 0.48.0 - [Commits](https://github.com/golang/crypto/compare/v0.47.0...v0.48.0) Updates `golang.org/x/net` from 0.49.0 to 0.51.0 - [Commits](https://github.com/golang/net/compare/v0.49.0...v0.51.0) Updates `golang.org/x/sync` from 0.19.0 to 0.20.0 - [Commits](https://github.com/golang/sync/compare/v0.19.0...v0.20.0) Updates `golang.org/x/sys` from 0.40.0 to 0.41.0 - [Commits](https://github.com/golang/sys/compare/v0.40.0...v0.41.0) Updates `golang.org/x/term` from 0.39.0 to 0.40.0 - [Commits](https://github.com/golang/term/compare/v0.39.0...v0.40.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-version: 0.48.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-version: 0.51.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sync dependency-version: 0.20.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-version: 0.41.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-version: 0.40.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 12 ++++++------ go.sum | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/go.mod b/go.mod index 63977295..c51638a5 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/slackhq/nebula -go 1.25 +go 1.25.0 require ( dario.cat/mergo v1.0.2 @@ -24,12 +24,12 @@ require ( github.com/stretchr/testify v1.11.1 github.com/vishvananda/netlink v1.3.1 go.yaml.in/yaml/v3 v3.0.4 - golang.org/x/crypto v0.47.0 + golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.49.0 - golang.org/x/sync v0.19.0 - golang.org/x/sys v0.40.0 - golang.org/x/term v0.39.0 + golang.org/x/net v0.52.0 + golang.org/x/sync v0.20.0 + golang.org/x/sys v0.43.0 + golang.org/x/term v0.42.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index e3b317f9..825af666 100644 --- a/go.sum +++ b/go.sum @@ -164,8 +164,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp 
v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -184,8 +184,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -193,8 +193,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= 
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -210,11 +210,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= -golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 72c04b90bd8ec45633a629ce30b629880211daf5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 13:27:14 -0500 Subject: [PATCH 23/44] Bump golang.zx2c4.com/wireguard/windows in the zx2c4-dependencies group (#1652) Bumps the zx2c4-dependencies group with 1 update: 
golang.zx2c4.com/wireguard/windows. Updates `golang.zx2c4.com/wireguard/windows` from 0.5.3 to 0.6.1 --- updated-dependencies: - dependency-name: golang.zx2c4.com/wireguard/windows dependency-version: 0.6.1 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: zx2c4-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index c51638a5..169cf1ca 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( golang.org/x/term v0.42.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b - golang.zx2c4.com/wireguard/windows v0.5.3 + golang.zx2c4.com/wireguard/windows v0.6.1 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe @@ -50,7 +50,7 @@ require ( github.com/prometheus/procfs v0.16.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/mod v0.31.0 // indirect + golang.org/x/mod v0.34.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.40.0 // indirect + golang.org/x/tools v0.43.0 // indirect ) diff --git a/go.sum b/go.sum index 825af666..d56177b7 100644 --- a/go.sum +++ b/go.sum @@ -172,8 +172,8 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.34.0 
h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -225,8 +225,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -235,8 +235,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= 
-golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= -golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= +golang.zx2c4.com/wireguard/windows v0.6.1 h1:XMaKojH1Hs/raMrmnir4n35nTvzvWj7NmSYzHn2F4qU= +golang.zx2c4.com/wireguard/windows v0.6.1/go.mod h1:04aqInu5GYuTFvMuDw/rKBAF7mHrltW/3rekpfbbZDM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= From 49e3c4649bcd4046f122ceb5f7985059b765faa0 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 17 Apr 2026 09:18:23 -0500 Subject: [PATCH 24/44] Try the hot new DefinedNet openbsd78 box (#1657) --- .github/workflows/smoke-extra.yml | 29 ++++++++++++++----- .github/workflows/smoke/smoke-vagrant.sh | 11 ------- .../Vagrantfile | 2 +- .../smoke/vagrant-openbsd-amd64/Vagrantfile | 4 +-- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index cdd6ea9d..3734db75 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -18,6 +18,8 @@ jobs: if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') name: Run extra smoke tests runs-on: ubuntu-latest + env: + VAGRANT_DEFAULT_PROVIDER: libvirt steps: - uses: actions/checkout@v6 @@ -30,11 +32,13 @@ jobs: - name: add hashicorp source run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list - - name: workaround AMD-V 
issue # https://github.com/cri-o/packaging/pull/306 - run: sudo rmmod kvm_amd - - - name: install vagrant - run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox + - name: install vagrant and libvirt + run: | + sudo apt-get update && sudo apt-get install -y vagrant libvirt-daemon-system libvirt-dev + sudo chmod 666 /dev/kvm + sudo usermod -aG libvirt $(whoami) + sudo chmod 666 /var/run/libvirt/libvirt-sock + vagrant plugin install vagrant-libvirt - name: freebsd-amd64 run: make smoke-vagrant/freebsd-amd64 @@ -45,10 +49,19 @@ jobs: - name: netbsd-amd64 run: make smoke-vagrant/netbsd-amd64 - - name: linux-386 - run: make smoke-vagrant/linux-386 - - name: linux-amd64-ipv6disable run: make smoke-vagrant/linux-amd64-ipv6disable + # linux-386 runs last because it requires disabling KVM to use VirtualBox, + # which prevents libvirt (used by the other tests) from working after this point. + - name: install virtualbox for i386 test + run: | + sudo apt-get install -y virtualbox + sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true + + - name: linux-386 + env: + VAGRANT_DEFAULT_PROVIDER: virtualbox + run: make smoke-vagrant/linux-386 + timeout-minutes: 30 diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh index 115bbde0..1c1e3c50 100755 --- a/.github/workflows/smoke/smoke-vagrant.sh +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -29,17 +29,6 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test docker run --name host2 --rm "$CONTAINER" -config host2.yml -test vagrant up - -# OpenBSD: synced folders are disabled because Vagrant's rsync installer -# uses ftp.openbsd.org which no longer hosts packages for older releases. -# Copy build artifacts in via scp instead. -case "$1" in - openbsd-*) - vagrant ssh -c "sudo mkdir -p /nebula" -- -T - tar -cf - -C build . 
| vagrant ssh -c "sudo tar -xf - -C /nebula && sudo chmod -R a+r /nebula" -- -T - ;; -esac - vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & diff --git a/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile index 89f94772..eeb9679e 100644 --- a/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile +++ b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile @@ -1,7 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| - config.vm.box = "ubuntu/jammy64" + config.vm.box = "bento/ubuntu-24.04" config.vm.synced_folder "../build", "/nebula" diff --git a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile index 4bbf3e0e..6dd26373 100644 --- a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile +++ b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile @@ -1,7 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| - config.vm.box = "generic/openbsd7" + config.vm.box = "DefinedNet/openbsd78" - config.vm.synced_folder ".", "/vagrant", disabled: true + config.vm.synced_folder "../build", "/nebula", type: "rsync" end From e80b9830a3a7aa0a7080fd6ebfd53b22cc70e6e4 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Mon, 20 Apr 2026 16:08:26 -0500 Subject: [PATCH 25/44] Remove more os.Exit calls and give a more reliable wait for stop function (attempt 3) (#1661) --- cmd/nebula-service/main.go | 16 ++- cmd/nebula/main.go | 16 ++- control.go | 78 ++++++++++- control_test.go | 1 + interface.go | 95 +++++++++---- main.go | 19 +-- overlay/tun_file_linux_test.go | 120 ++++++++++++++++ overlay/tun_linux.go | 242 
++++++++++++++++++++++++++++++--- service/service.go | 11 +- udp/conn.go | 6 +- udp/udp_darwin.go | 5 +- udp/udp_generic.go | 5 +- udp/udp_linux.go | 22 ++- udp/udp_rio_windows.go | 5 +- udp/udp_tester.go | 5 +- 15 files changed, 552 insertions(+), 94 deletions(-) create mode 100644 overlay/tun_file_linux_test.go diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 9a17b947..aaec80f7 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -78,8 +78,20 @@ func main() { } if !*configTest { - ctrl.Start() - ctrl.ShutdownBlock() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() + + if err := wait(); err != nil { + l.WithError(err).Error("Nebula stopped due to fatal error") + os.Exit(2) + } + + l.Info("Goodbye") } os.Exit(0) diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index ffdc15bf..f29f4537 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -72,9 +72,21 @@ func main() { } if !*configTest { - ctrl.Start() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() notifyReady(l) - ctrl.ShutdownBlock() + + if err := wait(); err != nil { + l.WithError(err).Error("Nebula stopped due to fatal error") + os.Exit(2) + } + + l.Info("Goodbye") } os.Exit(0) diff --git a/control.go b/control.go index f8567b50..75eccef1 100644 --- a/control.go +++ b/control.go @@ -2,9 +2,11 @@ package nebula import ( "context" + "errors" "net/netip" "os" "os/signal" + "sync" "syscall" "github.com/sirupsen/logrus" @@ -13,6 +15,20 @@ import ( "github.com/slackhq/nebula/overlay" ) +type RunState int + +const ( + StateUnknown RunState = iota + StateReady + StateStarted + StateStopping + StateStopped +) + +var ErrAlreadyStarted = errors.New("nebula is already started") +var ErrAlreadyStopped = errors.New("nebula cannot be restarted") +var 
ErrUnknownState = errors.New("nebula state is invalid") + // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc @@ -26,6 +42,9 @@ type controlHostLister interface { } type Control struct { + stateLock sync.Mutex + state RunState + f *Interface l *logrus.Logger ctx context.Context @@ -49,10 +68,31 @@ type ControlHostInfo struct { CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } -// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() -func (c *Control) Start() { +// Start actually runs nebula, this is a nonblocking call. +// The returned function blocks until nebula has fully stopped and returns the +// first fatal reader error (if any). A nil error means nebula shut down +// gracefully; a non-nil error means a reader hit an unexpected failure that +// triggered the shutdown. +func (c *Control) Start() (func() error, error) { + c.stateLock.Lock() + defer c.stateLock.Unlock() + switch c.state { + case StateReady: + //yay! + case StateStopped, StateStopping: + return nil, ErrAlreadyStopped + case StateStarted: + return nil, ErrAlreadyStarted + default: + return nil, ErrUnknownState + } + // Activate the interface - c.f.activate() + err := c.f.activate() + if err != nil { + c.state = StateStopped + return nil, err + } // Call all the delayed funcs that waited patiently for the interface to be created. if c.sshStart != nil { @@ -71,16 +111,40 @@ func (c *Control) Start() { c.lighthouseStart() } + c.f.triggerShutdown = c.Stop + // Start reading packets. 
- c.f.run() + out, err := c.f.run() + if err != nil { + c.state = StateStopped + return nil, err + } + c.state = StateStarted + return out, nil +} + +func (c *Control) State() RunState { + c.stateLock.Lock() + defer c.stateLock.Unlock() + return c.state } func (c *Control) Context() context.Context { return c.ctx } -// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete +// Stop is a non-blocking call that signals nebula to close all tunnels and shut down func (c *Control) Stop() { + c.stateLock.Lock() + if c.state != StateStarted { + c.stateLock.Unlock() + // We are stopping or stopped already + return + } + + c.state = StateStopping + c.stateLock.Unlock() + // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. c.cancel() @@ -89,7 +153,9 @@ func (c *Control) Stop() { if err := c.f.Close(); err != nil { c.l.WithError(err).Error("Close interface failed") } - c.l.Info("Goodbye") + c.stateLock.Lock() + c.state = StateStopped + c.stateLock.Unlock() } // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled diff --git a/control_test.go b/control_test.go index e8a5d312..558d8669 100644 --- a/control_test.go +++ b/control_test.go @@ -79,6 +79,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, &Interface{}) c := Control{ + state: StateReady, f: &Interface{ hostMap: hm, }, diff --git a/interface.go b/interface.go index 61f8c9b7..9e7a98a9 100644 --- a/interface.go +++ b/interface.go @@ -6,7 +6,7 @@ import ( "fmt" "io" "net/netip" - "os" + "sync" "sync/atomic" "time" @@ -87,6 +87,13 @@ type Interface struct { writers []udp.Conn readers []io.ReadWriteCloser + wg sync.WaitGroup + + // fatalErr holds the first unexpected reader error that caused shutdown. 
+ // nil means "no fatal error" (yet) + fatalErr atomic.Pointer[error] + // triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr + triggerShutdown func() metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -209,7 +216,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { // activate creates the interface on the host. After the interface is created, any // other services that want to bind listeners to its IP may do so successfully. However, // the interface isn't going to process anything until run() is called. -func (f *Interface) activate() { +func (f *Interface) activate() error { // actually turn on tun dev addr, err := f.outside.LocalAddr() @@ -237,27 +244,54 @@ func (f *Interface) activate() { if i > 0 { reader, err = f.inside.NewMultiQueueReader() if err != nil { - f.l.Fatal(err) + return err } } f.readers[i] = reader } - if err := f.inside.Activate(); err != nil { + f.wg.Add(1) // for us to wait on Close() to return + if err = f.inside.Activate(); err != nil { + f.wg.Done() f.inside.Close() - f.l.Fatal(err) + return err } + + return nil } -func (f *Interface) run() { +func (f *Interface) run() (func() error, error) { // Launch n queues to read packets from udp for i := 0; i < f.routines; i++ { - go f.listenOut(i) + f.wg.Go(func() { + f.listenOut(i) + }) } // Launch n queues to read packets from tun dev for i := 0; i < f.routines; i++ { - go f.listenIn(f.readers[i], i) + f.wg.Go(func() { + f.listenIn(f.readers[i], i) + }) + } + + return func() error { + f.wg.Wait() + if e := f.fatalErr.Load(); e != nil { + return *e + } + return nil + }, nil +} + +// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one +func (f *Interface) onFatal(err error) { + swapped := f.fatalErr.CompareAndSwap(nil, &err) + if !swapped { + return + } + if f.triggerShutdown != nil { + f.triggerShutdown() } } @@ -276,9 +310,16 @@ func (f 
*Interface) listenOut(i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) + + if err != nil && !f.closed.Load() { + f.l.WithError(err).Error("Error while reading inbound packet, closing") + f.onFatal(err) + } + + f.l.Debugf("underlay reader %v is done", i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -292,17 +333,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { - return + if !f.closed.Load() { + f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") + f.onFatal(err) } - - f.l.WithError(err).Error("Error while reading outbound packet") - // This only seems to happen when something fatal happens to the fd, so exit. 
- os.Exit(2) + break } f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) } + + f.l.Debugf("overlay reader %v is done", i) } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { @@ -477,23 +518,23 @@ func (f *Interface) GetCertState() *CertState { } func (f *Interface) Close() error { + var errs []error f.closed.Store(true) - for _, u := range f.writers { + // Release the udp readers + for i, u := range f.writers { err := u.Close() if err != nil { - f.l.WithError(err).Error("Error while closing udp socket") - } - } - for i, r := range f.readers { - if i == 0 { - continue // f.readers[0] is f.inside, which we want to save for last - } - if err := r.Close(); err != nil { - f.l.WithError(err).Error("Error while closing tun reader") + f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket") + errs = append(errs, err) } } - // Release the tun device - return f.inside.Close() + // Release the tun device (closing the tun also closes all readers) + closeErr := f.inside.Close() + if closeErr != nil { + errs = append(errs, closeErr) + } + f.wg.Done() + return errors.Join(errs...) 
} diff --git a/main.go b/main.go index 74979417..8adc2921 100644 --- a/main.go +++ b/main.go @@ -288,15 +288,16 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } return &Control{ - ifce, - l, - ctx, - cancel, - sshStart, - statsStart, - dnsStart, - lightHouse.StartUpdateWorker, - connManager.Start, + state: StateReady, + f: ifce, + l: l, + ctx: ctx, + cancel: cancel, + sshStart: sshStart, + statsStart: statsStart, + dnsStart: dnsStart, + lighthouseStart: lightHouse.StartUpdateWorker, + connectionManagerStart: connManager.Start, }, nil } diff --git a/overlay/tun_file_linux_test.go b/overlay/tun_file_linux_test.go new file mode 100644 index 00000000..5ab87e05 --- /dev/null +++ b/overlay/tun_file_linux_test.go @@ -0,0 +1,120 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package overlay + +import ( + "errors" + "os" + "sync" + "testing" + "time" + + "golang.org/x/sys/unix" +) + +// newReadPipe returns a read fd. The matching write fd is registered for cleanup. +// The caller takes ownership of the read fd (pass it to newTunFd / newFriend). +func newReadPipe(t *testing.T) int { + t.Helper() + var fds [2]int + if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil { + t.Fatalf("pipe2: %v", err) + } + t.Cleanup(func() { _ = unix.Close(fds[1]) }) + return fds[0] +} + +func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) { + tf, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + t.Cleanup(func() { _ = tf.Close() }) + + done := make(chan error, 1) + go func() { + _, err := tf.Read(make([]byte, 64)) + done <- err + }() + + // Verify Read is actually blocked in poll. 
+ select { + case err := <-done: + t.Fatalf("Read returned before shutdown signal: %v", err) + case <-time.After(50 * time.Millisecond): + } + + if err := tf.wakeForShutdown(); err != nil { + t.Fatalf("wakeForShutdown: %v", err) + } + + select { + case err := <-done: + if !errors.Is(err, os.ErrClosed) { + t.Fatalf("expected os.ErrClosed, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Read did not wake on shutdown") + } +} + +func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) { + parent, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + friend, err := parent.newFriend(newReadPipe(t)) + if err != nil { + _ = parent.Close() + t.Fatalf("newFriend: %v", err) + } + t.Cleanup(func() { + _ = friend.Close() + _ = parent.Close() + }) + + readers := []*tunFile{parent, friend} + errs := make([]error, len(readers)) + var wg sync.WaitGroup + for i, r := range readers { + wg.Add(1) + go func(i int, r *tunFile) { + defer wg.Done() + _, errs[i] = r.Read(make([]byte, 64)) + }(i, r) + } + + time.Sleep(50 * time.Millisecond) + + if err := parent.wakeForShutdown(); err != nil { + t.Fatalf("wakeForShutdown: %v", err) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("readers did not wake") + } + + for i, err := range errs { + if !errors.Is(err, os.ErrClosed) { + t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err) + } + } +} + +func TestTunFile_Close_Idempotent(t *testing.T) { + tf, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("second Close should be a no-op, got %v", err) + } +} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 9d779a4b..2830ff6b 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,6 +4,7 @@ package overlay 
import ( + "encoding/binary" "fmt" "io" "net" @@ -24,9 +25,175 @@ import ( "golang.org/x/sys/unix" ) +// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. +// A shared eventfd allows Close to wake all readers blocked in poll. +type tunFile struct { + fd int + shutdownFd int + lastOne bool + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed bool +} + +// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun +func (r *tunFile) newFriend(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + return &tunFile{ + fd: fd, + shutdownFd: r.shutdownFd, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + }, nil +} + +func newTunFd(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &tunFile{ + fd: fd, + shutdownFd: shutdownFd, + lastOne: true, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + } + + return out, nil +} + +func (r *tunFile) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! 
+ tunEvents := r.readPoll[0].Revents + shutdownEvents := r.readPoll[1].Revents + r.readPoll[0].Revents = 0 + r.readPoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.writePoll[0].Revents + shutdownEvents := r.writePoll[1].Revents + r.writePoll[0].Revents = 0 + r.writePoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) Read(buf []byte) (int, error) { + for { + if n, err := unix.Read(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnRead(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +func (r *tunFile) Write(buf []byte) (int, error) { + for { + if n, err := unix.Write(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnWrite(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +func (r *tunFile) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(r.readPoll[1].Fd), buf[:]) + return err +} + +func (r *tunFile) 
Close() error { + if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem + return nil + } + r.closed = true + if r.lastOne { + _ = unix.Close(r.shutdownFd) + } + return unix.Close(r.fd) +} + type tun struct { - io.ReadWriteCloser - fd int + *tunFile + readers []*tunFile + closeLock sync.Mutex Device string vpnNetworks []netip.Prefix MaxMTU int @@ -72,9 +239,7 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) if err != nil { return nil, err } @@ -115,6 +280,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu nameStr := c.GetString("tun.dev", "") copy(req.Name[:], nameStr) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, &NameError{ Name: nameStr, Underlying: err, @@ -122,8 +288,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } name := strings.Trim(string(req.Name[:]), "\x00") - file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, fd, vpnNetworks) if err != nil { return nil, err } @@ -133,10 +298,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. 
+func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { + tfd, err := newTunFd(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } t := &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), + tunFile: tfd, + readers: []*tunFile{tfd}, + closeLock: sync.Mutex{}, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), @@ -145,8 +317,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n l: l, } - err := t.reload(c, true) - if err != nil { + if err = t.reload(c, true); err != nil { + _ = t.Close() return nil, err } @@ -239,6 +411,9 @@ func (t *tun) SupportsMultiqueue() bool { } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + t.closeLock.Lock() + defer t.closeLock.Unlock() + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -248,12 +423,19 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + out, err := t.tunFile.newFriend(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } - return file, nil + t.readers = append(t.readers, out) + + return out, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -684,18 +866,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } func (t *tun) Close() error { + t.closeLock.Lock() + defer t.closeLock.Unlock() + if t.routeChan != nil { close(t.routeChan) + t.routeChan = nil } - if t.ReadWriteCloser != nil { - _ = t.ReadWriteCloser.Close() - } + // Signal all readers blocked in poll to wake up and exit + _ = t.tunFile.wakeForShutdown() if t.ioctlFd > 0 { - _ = os.NewFile(t.ioctlFd, 
"ioctlFd").Close() + _ = unix.Close(int(t.ioctlFd)) t.ioctlFd = 0 } - return nil + for i := range t.readers { + if i == 0 { + continue //we want to close the zeroth reader last + } + err := t.readers[i].Close() + if err != nil { + t.l.WithField("reader", i).WithError(err).Error("error closing tun reader") + } else { + t.l.WithField("reader", i).Info("closed tun reader") + } + } + + //this is t.readers[0] too + err := t.tunFile.Close() + if err != nil { + t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader") + } else { + t.l.WithField("reader", 0).Info("closed tun reader") + } + return err } diff --git a/service/service.go b/service/service.go index fc8ac97a..899e851d 100644 --- a/service/service.go +++ b/service/service.go @@ -44,7 +44,10 @@ type Service struct { } func New(control *nebula.Control) (*Service, error) { - control.Start() + wait, err := control.Start() + if err != nil { + return nil, err + } ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) @@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) { } }) + // Add the nebula wait function to the group so a fatal reader error + // propagates out through errgroup.Wait(). 
+ eg.Go(func() error { + return wait() + }) + return &s, nil } diff --git a/udp/conn.go b/udp/conn.go index 1ae585c2..30d89dec 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -16,7 +16,7 @@ type EncReader func( type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) + ListenOut(r EncReader) error WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool @@ -31,8 +31,8 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return +func (NoopConn) ListenOut(_ EncReader) error { + return nil } func (NoopConn) SupportsMultipleReaders() bool { return false diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 91201194..863c98f3 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) for { @@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) { n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } u.l.WithError(err).Error("unexpected udp socket receive error") diff --git a/udp/udp_generic.go b/udp/udp_generic.go index e9dad6c5..ad26f794 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -73,7 +73,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) { +func (u *GenericConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -83,8 +83,7 @@ func (u *GenericConn) ListenOut(r EncReader) { n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - 
return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { diff --git a/udp/udp_linux.go b/udp/udp_linux.go index b1490a1c..21a34147 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { return int(n), true, nil } -func (u *StdConn) listenOutSingle(r EncReader) { +func (u *StdConn) listenOutSingle(r EncReader) error { var err error var n int var from netip.AddrPort @@ -180,15 +180,14 @@ func (u *StdConn) listenOutSingle(r EncReader) { for { n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) r(from, buffer[:n]) } } -func (u *StdConn) listenOutBatch(r EncReader) { +func (u *StdConn) listenOutBatch(r EncReader) error { var ip netip.Addr var n int var operr error @@ -205,12 +204,10 @@ func (u *StdConn) listenOutBatch(r EncReader) { for { err := u.rawConn.Read(reader) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } if operr != nil { - u.l.WithError(operr).Debug("operr: udp socket is closed, exiting read loop") - return + return operr } for i := 0; i < n; i++ { @@ -225,14 +222,11 @@ func (u *StdConn) listenOutBatch(r EncReader) { } } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { if u.batch == 1 { - //save some ram by not calling PrepareRawMessages for fields we won't use - //we could also make this path more common by calling recvmmsg with msgs[:1], - //but that's still the recvmmsg syscall, which would be a change - u.listenOutSingle(r) + return u.listenOutSingle(r) } else { - u.listenOutBatch(r) + return u.listenOutBatch(r) } } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 3d60f34c..607b978e 100644 --- 
a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) { +func (u *RIOConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -151,8 +151,7 @@ func (u *RIOConn) ListenOut(r EncReader) { if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5f0f7765..5db72555 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -6,6 +6,7 @@ package udp import ( "io" "net/netip" + "os" "sync/atomic" "github.com/sirupsen/logrus" @@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader) { +func (u *TesterConn) ListenOut(r EncReader) error { for { p, ok := <-u.RxPackets if !ok { - return + return os.ErrClosed } r(p.From, p.Data) } From 3d34cc9b749869734da4674a2ce361db28bded0a Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 20 Apr 2026 16:38:14 -0500 Subject: [PATCH 26/44] Try to make smoke less flakey (#1663) --- .github/workflows/smoke/build-relay.sh | 8 ++-- .github/workflows/smoke/build.sh | 18 +++++--- .github/workflows/smoke/smoke-relay.sh | 57 ++++++++++++++++++++---- .github/workflows/smoke/smoke-vagrant.sh | 47 +++++++++++++++++-- .github/workflows/smoke/smoke.sh | 56 +++++++++++++++++++---- 5 files changed, 156 insertions(+), 30 deletions(-) diff --git a/.github/workflows/smoke/build-relay.sh b/.github/workflows/smoke/build-relay.sh index 70b07f4e..249e6c84 100755 --- a/.github/workflows/smoke/build-relay.sh +++ b/.github/workflows/smoke/build-relay.sh @@ -16,8 +16,10 @@ relay: am_relay: true EOF - export LIGHTHOUSES="192.168.100.1 
172.17.0.2:4242" - export REMOTE_ALLOW_LIST='{"172.17.0.4/32": false, "172.17.0.5/32": false}' + # TEST-NET-3 placeholder IPs; smoke-relay.sh seds them to real container IPs. + # Mapping: .2 lighthouse1, .3 host2, .4 host3, .5 host4. + export LIGHTHOUSES="192.168.100.1 203.0.113.2:4242" + export REMOTE_ALLOW_LIST='{"203.0.113.4/32": false, "203.0.113.5/32": false}' HOST="host2" ../genconfig.sh >host2.yml <host3.yml diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index dcd132b0..b23516ee 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -5,9 +5,15 @@ set -e -x rm -rf ./build mkdir ./build -# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1 -# - We could make this better by launching the lighthouse first and then fetching what IP it is. -NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)" +# Smoke containers run on a dedicated docker network whose subnet is allocated +# at smoke time, not known at build time. Configs are written with TEST-NET-3 +# placeholder IPs (RFC 5737) and smoke.sh / smoke-vagrant.sh / smoke-relay.sh +# sed the real container IPs in before starting nebula. +# +# Placeholder mapping (last octet == fixed container slot): +# 203.0.113.2 -> lighthouse1, 203.0.113.3 -> host2, +# 203.0.113.4 -> host3, 203.0.113.5 -> host4. 
+LIGHTHOUSE_IP="203.0.113.2" ( cd build @@ -25,16 +31,16 @@ NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ ../genconfig.sh >lighthouse1.yml HOST="host2" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ ../genconfig.sh >host2.yml HOST="host3" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host3.yml HOST="host4" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host4.yml diff --git a/.github/workflows/smoke/smoke-relay.sh b/.github/workflows/smoke/smoke-relay.sh index 9c113e18..aa1cd915 100755 --- a/.github/workflows/smoke/smoke-relay.sh +++ b/.github/workflows/smoke/smoke-relay.sh @@ -6,6 +6,8 @@ set -o pipefail mkdir -p logs +NETWORK="nebula-smoke-relay" + cleanup() { echo echo " *** cleanup" @@ -16,22 +18,53 @@ cleanup() { then docker kill lighthouse1 host2 host3 host4 fi + docker network rm "$NETWORK" >/dev/null 2>&1 } trap cleanup EXIT -docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test -docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test -docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test -docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! 
docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2, .4 host3, .5 host4 — matches the placeholders in build-relay.sh. +SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" +HOST3_IP="$PREFIX.4" +HOST4_IP="$PREFIX.5" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. +for f in build/host2.yml build/host3.yml build/host4.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + +docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" nebula:smoke-relay -config host2.yml -test +docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" nebula:smoke-relay -config host3.yml -test +docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" nebula:smoke-relay -config host4.yml -test + +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay 
-config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 set +x @@ -76,7 +109,13 @@ docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' -sleep 5 + +# Wait up to 30s for all backgrounded jobs to exit rather than relying on a +# fixed sleep. +for _ in $(seq 1 30); do + [ -z "$(jobs -r)" ] && break + sleep 1 +done if [ "$(jobs -r)" ] then diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh index 1c1e3c50..e3863cb5 100755 --- a/.github/workflows/smoke/smoke-vagrant.sh +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -8,6 +8,8 @@ export VAGRANT_CWD="$PWD/vagrant-$1" mkdir -p logs +NETWORK="nebula-smoke" + cleanup() { echo echo " *** cleanup" @@ -19,21 +21,51 @@ cleanup() { docker kill lighthouse1 host2 fi vagrant destroy -f + docker network rm "$NETWORK" >/dev/null 2>&1 } trap cleanup EXIT +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). 
Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi + +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2 — matches the placeholders in build.sh. +SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. +# This must happen before `vagrant up` rsyncs build/ into the VM for host3. +for f in build/host2.yml build/host3.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + CONTAINER="nebula:${NAME:-smoke}" docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test -docker run --name host2 --rm "$CONTAINER" -config host2.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test vagrant up vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add 
NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 15 @@ -96,7 +128,14 @@ vagrant ssh -c "ping -c1 192.168.100.2" -- -T vagrant ssh -c "sudo xargs kill /dev/null 2>&1 } trap cleanup EXIT +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi + +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2, .4 host3, .5 host4 — matches the placeholders in build.sh. +SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" +HOST3_IP="$PREFIX.4" +HOST4_IP="$PREFIX.5" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. +# build/lighthouse1.yml has no IPs to rewrite so it's skipped. 
+for f in build/host2.yml build/host3.yml build/host4.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + CONTAINER="nebula:${NAME:-smoke}" docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test -docker run --name host2 --rm "$CONTAINER" -config host2.yml -test -docker run --name host3 --rm "$CONTAINER" -config host3.yml -test -docker run --name host4 --rm "$CONTAINER" -config host4.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test +docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" "$CONTAINER" -config host3.yml -test +docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" "$CONTAINER" -config host4.yml -test -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" 
-config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 # grab tcpdump pcaps for debugging @@ -131,7 +165,13 @@ docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' -sleep 5 + +# Wait up to 30s for all backgrounded jobs to exit rather than relying on a +# fixed sleep. +for _ in $(seq 1 30); do + [ -z "$(jobs -r)" ] && break + sleep 1 +done if [ "$(jobs -r)" ] then From 8c71f2f3f96660c23ecb2faf0af984796f6890b3 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 21 Apr 2026 10:45:46 -0500 Subject: [PATCH 27/44] FreeBSD tun needs to be non blocking as well (#1666) --- overlay/tun_freebsd.go | 263 +++++++++++++++++++++++++++++------------ 1 file changed, 186 insertions(+), 77 deletions(-) diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 2f65b3a4..91c51159 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -10,6 +10,7 @@ import ( "io" "io/fs" "net/netip" + "os" "sync/atomic" "syscall" "time" @@ -93,107 +94,184 @@ type tun struct { routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr l *logrus.Logger - devFd int + + fd int + shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls + shutdownW int // write end of the shutdown pipe; closing this signals shutdown to any blocked reader/writer + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed atomic.Bool +} + +// blockOnRead waits until the tun fd is readable or shutdown has been signaled. 
+// Returns os.ErrClosed if Close was called. +func (t *tun) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.readPoll[0].Revents + shutdownEvents := t.readPoll[1].Revents + t.readPoll[0].Revents = 0 + t.readPoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *tun) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.writePoll[0].Revents + shutdownEvents := t.writePoll[1].Revents + t.writePoll[0].Revents = 0 + t.writePoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil } func (t *tun) Read(to []byte) (int, error) { - // use readv() to read from the tunnel device, to eliminate the need for copying the buffer - if t.devFd < 0 { - return -1, syscall.EINVAL - } - // first 4 bytes is protocol family, in network byte order - head := make([]byte, 4) - - iovecs := []syscall.Iovec{ + var head [4]byte + iovecs := [2]syscall.Iovec{ {&head[0], 4}, {&to[0], uint64(len(to))}, } - - n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) - - var err error - if errno != 0 { - err = syscall.Errno(errno) - } else { - err = nil - } - // fix bytes read number to exclude header - bytesRead := int(n) - if bytesRead < 0 { - return bytesRead, err - } else if bytesRead < 4 { - return 0, err - } else { - return bytesRead - 4, err + for { + n, _, errno := syscall.Syscall(syscall.SYS_READV, 
uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) + if errno == 0 { + bytesRead := int(n) + if bytesRead < 4 { + return 0, nil + } + return bytesRead - 4, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnRead(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } } } // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { - // use writev() to write to the tunnel device, to eliminate the need for copying the buffer - if t.devFd < 0 { - return -1, syscall.EINVAL - } - if len(from) <= 1 { return 0, syscall.EIO } + ipVer := from[0] >> 4 - var head []byte + var head [4]byte // first 4 bytes is protocol family, in network byte order - if ipVer == 4 { - head = []byte{0, 0, 0, syscall.AF_INET} - } else if ipVer == 6 { - head = []byte{0, 0, 0, syscall.AF_INET6} - } else { + switch ipVer { + case 4: + head[3] = syscall.AF_INET + case 6: + head[3] = syscall.AF_INET6 + default: return 0, fmt.Errorf("unable to determine IP version from packet") } - iovecs := []syscall.Iovec{ + + iovecs := [2]syscall.Iovec{ {&head[0], 4}, {&from[0], uint64(len(from))}, } - - n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) - - var err error - if errno != 0 { - err = syscall.Errno(errno) - } else { - err = nil + for { + n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) + if errno == 0 { + return int(n) - 4, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnWrite(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } } - - return int(n) - 4, err } func (t *tun) Close() error { - if t.devFd >= 0 { - err := syscall.Close(t.devFd) - if err != nil { + if t.closed.Swap(true) { + return nil + } + + // Closing the write end of the 
shutdown pipe causes any blocked Poll to + // return with POLLHUP on the shutdown fd, so readers/writers wake up and + // exit with os.ErrClosed. + if t.shutdownW >= 0 { + _ = unix.Close(t.shutdownW) + t.shutdownW = -1 + } + + if t.fd >= 0 { + if err := unix.Close(t.fd); err != nil { t.l.WithError(err).Error("Error closing device") } - t.devFd = -1 + t.fd = -1 + } - c := make(chan struct{}) - go func() { - // destroying the interface can block if a read() is still pending. Do this asynchronously. - defer close(c) - s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) - if err == nil { - defer syscall.Close(s) - ifreq := ifreqDestroy{Name: t.deviceBytes()} - err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) - } - if err != nil { - t.l.WithError(err).Error("Error destroying tunnel") - } - }() + if t.shutdownR >= 0 { + _ = unix.Close(t.shutdownR) + t.shutdownR = -1 + } - // wait up to 1 second so we start blocking at the ioctl - select { - case <-c: - case <-time.After(1 * time.Second): + c := make(chan struct{}) + go func() { + // destroying the interface can block if a read() is still pending. Do this asynchronously. 
+ defer close(c) + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err == nil { + defer syscall.Close(s) + ifreq := ifreqDestroy{Name: t.deviceBytes()} + err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) } + if err != nil { + t.l.WithError(err).Error("Error destroying tunnel") + } + }() + + // wait up to 1 second so we start blocking at the ioctl + select { + case <-c: + case <-time.After(1 * time.Second): } return nil @@ -209,16 +287,38 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( var err error deviceName := c.GetString("tun.dev", "") if deviceName != "" { - fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0) + fd, err = unix.Open("/dev/"+deviceName, os.O_RDWR, 0) } if errors.Is(err, fs.ErrNotExist) || deviceName == "" { // If the device doesn't already exist, request a new one and rename it - fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0) + fd, err = unix.Open("/dev/tun", os.O_RDWR, 0) } if err != nil { return nil, err } + if err = unix.SetNonblock(fd, true); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to set tun device as nonblocking: %w", err) + } + + // Shutdown pipe lets Close wake any reader/writer blocked in Poll. 
+ var pipeFds [2]int + if err = unix.Pipe2(pipeFds[:], unix.O_CLOEXEC|unix.O_NONBLOCK); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to create shutdown pipe: %w", err) + } + shutdownR, shutdownW := pipeFds[0], pipeFds[1] + + closeOnErr := true + defer func() { + if closeOnErr { + _ = unix.Close(fd) + _ = unix.Close(shutdownR) + _ = unix.Close(shutdownW) + } + }() + // Read the name of the interface var name [16]byte arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)} @@ -237,7 +337,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } if ctrlErr != nil { - return nil, err + return nil, ctrlErr } ifName := string(bytes.TrimRight(name[:], "\x00")) @@ -253,8 +353,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } defer syscall.Close(s) - fd := uintptr(s) - var fromName [16]byte var toName [16]byte copy(fromName[:], ifName) @@ -266,7 +364,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } // Set the device name - _ = ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) + _ = ioctl(uintptr(s), syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) } t := &tun{ @@ -274,13 +372,24 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, - devFd: fd, + fd: fd, + shutdownR: shutdownR, + shutdownW: shutdownW, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownR), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownR), Events: unix.POLLIN}, + }, } err = t.reload(c, true) if err != nil { return nil, err } + closeOnErr = false c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) From 2f4532f1026f78028c380aeb937d0f5baf41eab4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 21 Apr 2026 12:41:10 -0500 Subject: [PATCH 
28/44] No more dns globals, proper cleanup on shutdown (#1667) --- dns_server.go | 255 ++++++++++++++++++++++++++++++++++++--------- dns_server_test.go | 220 +++++++++++++++++++++++++++++++++++++- hostmap.go | 4 +- interface.go | 6 +- main.go | 21 +--- 5 files changed, 432 insertions(+), 74 deletions(-) diff --git a/dns_server.go b/dns_server.go index 73576546..75c56f0f 100644 --- a/dns_server.go +++ b/dns_server.go @@ -1,12 +1,14 @@ package nebula import ( + "context" "fmt" "net" "net/netip" "strconv" "strings" "sync" + "sync/atomic" "github.com/gaissmai/bart" "github.com/miekg/dns" @@ -14,32 +16,207 @@ import ( "github.com/slackhq/nebula/config" ) -// This whole thing should be rewritten to use context - -var dnsR *dnsRecords -var dnsServer *dns.Server -var dnsAddr string - -type dnsRecords struct { +type dnsServer struct { sync.RWMutex l *logrus.Logger + ctx context.Context dnsMap4 map[string]netip.Addr dnsMap6 map[string]netip.Addr hostMap *HostMap myVpnAddrsTable *bart.Lite + + mux *dns.ServeMux + + // enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`. + // Start, Add, and reload consult it so callers don't need to know the + // gating rules. When it toggles off via reload, accumulated records are + // cleared so a later re-enable starts with a fresh map populated from + // new handshakes. + enabled atomic.Bool + + serverMu sync.Mutex + server *dns.Server + // started is closed once `server` has finished binding (or after + // ListenAndServe returns on a bind failure). Stop waits on it before + // calling Shutdown to avoid the miekg/dns "server not started" race + // where a Shutdown that arrives before bind completes is silently + // ignored, leaving the listener running forever. + started chan struct{} + addr string } -func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { - return &dnsRecords{ +// newDnsServerFromConfig builds a dnsServer, applies the initial config, and +// registers a reload callback. 
The reload callback is registered before the +// initial config is applied, so a SIGHUP can later enable, fix, or disable +// DNS even if the initial application failed. +// +// The dnsServer internally gates on `lighthouse.serve_dns && +// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally, +// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel +// watcher that tears the listener down on nebula shutdown. The returned +// pointer is always non-nil, even on error. +func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { + ds := &dnsServer{ l: l, + ctx: ctx, dnsMap4: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr), hostMap: hostMap, myVpnAddrsTable: cs.myVpnAddrsTable, } + ds.mux = dns.NewServeMux() + ds.mux.HandleFunc(".", ds.handleDnsRequest) + + c.RegisterReloadCallback(func(c *config.C) { + if err := ds.reload(c, false); err != nil { + l.WithError(err).Error("Failed to reload DNS responder from config") + } + }) + + if err := ds.reload(c, true); err != nil { + return ds, err + } + return ds, nil } -func (d *dnsRecords) Query(q uint16, data string) netip.Addr { +// reload applies the latest config and reconciles the running state with it: +// - enabled toggled on -> spawn a runner +// - enabled toggled off -> stop the runner +// - listen address changed (while running) -> restart on the new address +// - everything else -> no-op +// +// On the initial call it only records configuration; Control.Start is what +// launches the first runner via dnsStart. 
+func (d *dnsServer) reload(c *config.C, initial bool) error { + wantsDns := c.GetBool("lighthouse.serve_dns", false) + amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) + enabled := wantsDns && amLighthouse + newAddr := getDnsServerAddr(c) + + d.serverMu.Lock() + running := d.server + runningStarted := d.started + sameAddr := d.addr == newAddr + d.addr = newAddr + d.enabled.Store(enabled) + d.serverMu.Unlock() + + if initial { + if wantsDns && !amLighthouse { + d.l.Warn("DNS server refusing to run because this host is not a lighthouse.") + } + return nil + } + + if !enabled { + if running != nil { + d.Stop() + } + // Drop any records that accumulated while enabled; a later re-enable + // will repopulate from fresh handshakes. + d.clearRecords() + return nil + } + + if running == nil { + // Was disabled (or never started); bring it up now. + go d.Start() + return nil + } + + if sameAddr { + return nil + } + + d.shutdownServer(running, runningStarted, "reload") + // Old Start goroutine has now exited; bring up a fresh listener on the + // new address. + go d.Start() + return nil +} + +// shutdownServer waits for the server to finish binding (so Shutdown actually +// stops it rather than no-oping) and then shuts it down. +func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) { + if srv == nil { + return + } + if started != nil { + <-started + } + if err := srv.Shutdown(); err != nil { + d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder") + } +} + +// Start binds and serves the DNS responder. Blocks until Stop is called or +// the listener errors. Safe to call when DNS is disabled (returns +// immediately). This is what Control.dnsStart points at. +// +// Must be invoked after the tun device is active so that lighthouse.dns.host +// may bind to a nebula IP. 
+func (d *dnsServer) Start() { + if !d.enabled.Load() { + return + } + + started := make(chan struct{}) + d.serverMu.Lock() + if d.ctx.Err() != nil { + d.serverMu.Unlock() + return + } + addr := d.addr + server := &dns.Server{ + Addr: addr, + Net: "udp", + Handler: d.mux, + NotifyStartedFunc: func() { close(started) }, + } + d.server = server + d.started = started + d.serverMu.Unlock() + + // Per-invocation ctx watcher. Exits when Start does, so we don't leak a + // watcher per reload-driven restart. + done := make(chan struct{}) + go func() { + select { + case <-d.ctx.Done(): + d.shutdownServer(server, started, "shutdown") + case <-done: + } + }() + + d.l.WithField("dnsListener", addr).Info("Starting DNS responder") + err := server.ListenAndServe() + close(done) + + // If the listener never bound (bind error) NotifyStartedFunc never fires, + // so close started here to release any Stop caller waiting on it. + select { + case <-started: + default: + close(started) + } + + if err != nil { + d.l.WithError(err).Warn("Failed to run the DNS responder") + } +} + +// Stop shuts down the active server, if any. Idempotent. +func (d *dnsServer) Stop() { + d.serverMu.Lock() + srv := d.server + started := d.started + d.server = nil + d.started = nil + d.serverMu.Unlock() + d.shutdownServer(srv, started, "stop") +} + +func (d *dnsServer) Query(q uint16, data string) netip.Addr { data = strings.ToLower(data) d.RLock() defer d.RUnlock() @@ -57,7 +234,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr { return netip.Addr{} } -func (d *dnsRecords) QueryCert(data string) string { +func (d *dnsServer) QueryCert(data string) string { ip, err := netip.ParseAddr(data[:len(data)-1]) if err != nil { return "" @@ -80,8 +257,19 @@ func (d *dnsRecords) QueryCert(data string) string { return string(b) } +// clearRecords drops all DNS records. 
+func (d *dnsServer) clearRecords() { + d.Lock() + defer d.Unlock() + clear(d.dnsMap4) + clear(d.dnsMap6) +} + // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` -func (d *dnsRecords) Add(host string, addresses []netip.Addr) { +func (d *dnsServer) Add(host string, addresses []netip.Addr) { + if !d.enabled.Load() { + return + } host = strings.ToLower(host) d.Lock() defer d.Unlock() @@ -101,7 +289,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) { } } -func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { +func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { a, _, _ := net.SplitHostPort(addr) b, err := netip.ParseAddr(a) if err != nil { @@ -116,7 +304,7 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { return d.myVpnAddrsTable.Contains(b) } -func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { case dns.TypeA, dns.TypeAAAA: @@ -150,7 +338,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { } } -func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false @@ -163,21 +351,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(l, cs, hostMap) - - // attach request handler func - dns.HandleFunc(".", dnsR.handleDnsRequest) - - c.RegisterReloadCallback(func(c *config.C) { - reloadDns(l, c) - }) - - return func() { - startDns(l, c) - } -} - func getDnsServerAddr(c *config.C) string { dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", "")) // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't 
resolve. @@ -186,25 +359,3 @@ func getDnsServerAddr(c *config.C) string { } return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } - -func startDns(l *logrus.Logger, c *config.C) { - dnsAddr = getDnsServerAddr(c) - dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} - l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") - err := dnsServer.ListenAndServe() - defer dnsServer.Shutdown() - if err != nil { - l.Errorf("Failed to start server: %s\n ", err.Error()) - } -} - -func reloadDns(l *logrus.Logger, c *config.C) { - if dnsAddr == getDnsServerAddr(c) { - l.Debug("No DNS server config change detected") - return - } - - l.Debug("Restarting DNS server") - dnsServer.Shutdown() - go startDns(l, c) -} diff --git a/dns_server_test.go b/dns_server_test.go index 356e5890..c33c0480 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,19 +1,31 @@ package nebula import ( + "context" + "io" + "net" "net/netip" + "strconv" "testing" + "time" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParsequery(t *testing.T) { l := logrus.New() hostMap := &HostMap{} - ds := newDnsRecords(l, &CertState{}, hostMap) + ds := &dnsServer{ + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + } + ds.enabled.Store(true) addrs := []netip.Addr{ netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5"), @@ -71,3 +83,209 @@ func Test_getDnsServerAddr(t *testing.T) { } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) } + +func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { + t.Helper() + l := logrus.New() + l.Out = io.Discard + ds := &dnsServer{ + l: l, + ctx: context.Background(), + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: &HostMap{}, + } + ds.mux = dns.NewServeMux() + ds.mux.HandleFunc(".", 
ds.handleDnsRequest) + return ds, config.NewC(l) +} + +func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) { + c.Settings["lighthouse"] = map[string]any{ + "am_lighthouse": amLighthouse, + "serve_dns": serveDns, + "dns": map[string]any{ + "host": host, + "port": port, + }, + } +} + +func TestDnsServer_reload_initial_disabled(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, false) + + require.NoError(t, ds.reload(c, true)) + assert.False(t, ds.enabled.Load()) + assert.Equal(t, "127.0.0.1:0", ds.addr) + assert.Nil(t, ds.server) +} + +func TestDnsServer_reload_initial_enabled(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, true) + + require.NoError(t, ds.reload(c, true)) + assert.True(t, ds.enabled.Load()) + assert.Equal(t, "127.0.0.1:0", ds.addr) + // initial never starts a runner; that's Control.Start's job + assert.Nil(t, ds.server) +} + +func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", false, true) + + require.NoError(t, ds.reload(c, true)) + // Wants DNS but isn't a lighthouse: gated off, no runner. + assert.False(t, ds.enabled.Load()) +} + +func TestDnsServer_reload_sameAddr_noOp(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, true) + + require.NoError(t, ds.reload(c, true)) + // No server running yet, no addr change. Reload should not spawn anything. + require.NoError(t, ds.reload(c, false)) + assert.True(t, ds.enabled.Load()) + assert.Nil(t, ds.server) +} + +func TestDnsServer_StartStop_lifecycle(t *testing.T) { + // Bind to a real (random) UDP port so we exercise the actual + // ListenAndServe + Shutdown plumbing including the started-chan race fix. 
+ port := freeUDPPort(t) + + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", port, true, true) + require.NoError(t, ds.reload(c, true)) + + done := make(chan struct{}) + go func() { + ds.Start() + close(done) + }() + + waitFor(t, func() bool { + ds.serverMu.Lock() + started := ds.started + ds.serverMu.Unlock() + if started == nil { + return false + } + select { + case <-started: + return true + default: + return false + } + }) + + ds.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after Stop") + } +} + +func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) { + // Stop called immediately after Start should not deadlock even if bind + // hasn't completed yet. This exercises the started-chan close-on-bind-fail + // path: by binding to an obviously bad port (privileged) we get a fast + // bind error before NotifyStartedFunc fires. + ds, c := newTestDnsServer(t) + // Use a port that should fail to bind (negative would be invalid, use a + // host that won't resolve to ensure listenUDP fails quickly). + setDnsConfig(c, "256.256.256.256", "53", true, true) + require.NoError(t, ds.reload(c, true)) + + done := make(chan struct{}) + go func() { + ds.Start() + close(done) + }() + + // Give Start a moment to attempt the bind and fail. + select { + case <-done: + // Bind failed and Start returned; Stop should be a no-op. 
+ case <-time.After(time.Second): + t.Fatal("Start did not return after a bad bind") + } + + stopped := make(chan struct{}) + go func() { + ds.Stop() + close(stopped) + }() + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatal("Stop hung after a failed bind") + } +} + +func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) { + port := freeUDPPort(t) + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", port, true, true) + require.NoError(t, ds.reload(c, true)) + + startReturned := make(chan struct{}) + go func() { + ds.Start() + close(startReturned) + }() + waitForBind(t, ds) + + // Toggle serve_dns off; reload should shut the running server down. + setDnsConfig(c, "127.0.0.1", port, true, false) + require.NoError(t, ds.reload(c, false)) + select { + case <-startReturned: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after reload disabled DNS") + } + assert.False(t, ds.enabled.Load()) +} + +func freeUDPPort(t *testing.T) string { + t.Helper() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + port := conn.LocalAddr().(*net.UDPAddr).Port + require.NoError(t, conn.Close()) + return strconv.Itoa(port) +} + +func waitForBind(t *testing.T, ds *dnsServer) { + t.Helper() + waitFor(t, func() bool { + ds.serverMu.Lock() + started := ds.started + ds.serverMu.Unlock() + if started == nil { + return false + } + select { + case <-started: + return true + default: + return false + } + }) +} + +func waitFor(t *testing.T, cond func() bool) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatal("timed out waiting for condition") +} diff --git a/hostmap.go b/hostmap.go index 7e2939e0..25181d83 100644 --- a/hostmap.go +++ b/hostmap.go @@ -604,9 +604,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI // unlockedAddHostInfo assumes you have a 
write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. // If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { - if f.serveDns { + if f.dnsServer != nil { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) + f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } for _, addr := range hostinfo.vpnAddrs { hm.unlockedInnerAddHostInfo(addr, hostinfo, f) diff --git a/interface.go b/interface.go index 9e7a98a9..481b1d4d 100644 --- a/interface.go +++ b/interface.go @@ -29,7 +29,7 @@ type InterfaceConfig struct { pki *PKI Cipher string Firewall *Firewall - ServeDns bool + DnsServer *dnsServer HandshakeManager *HandshakeManager lightHouse *LightHouse connectionManager *connectionManager @@ -57,7 +57,7 @@ type Interface struct { firewall *Firewall connectionManager *connectionManager handshakeManager *HandshakeManager - serveDns bool + dnsServer *dnsServer createTime time.Time lightHouse *LightHouse myBroadcastAddrsTable *bart.Lite @@ -175,7 +175,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { outside: c.Outside, inside: c.Inside, firewall: c.Firewall, - serveDns: c.ServeDns, + dnsServer: c.DnsServer, handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, diff --git a/main.go b/main.go index 8adc2921..0ac63dfa 100644 --- a/main.go +++ b/main.go @@ -215,13 +215,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - serveDns := false - if c.GetBool("lighthouse.serve_dns", false) { - if c.GetBool("lighthouse.am_lighthouse", false) { - serveDns = true - } else { - l.Warn("DNS server refusing to 
run because this host is not a lighthouse.") - } + ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) + if err != nil { + l.WithError(err).Warn("Failed to start DNS responder") } ifConfig := &InterfaceConfig{ @@ -230,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Outside: udpConns[0], pki: pki, Firewall: fw, - ServeDns: serveDns, + DnsServer: ds, HandshakeManager: handshakeManager, connectionManager: connManager, lightHouse: lightHouse, @@ -280,13 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg attachCommands(l, c, ssh, ifce) - // Start DNS server last to allow using the nebula IP as lighthouse.dns.host - var dnsStart func() - if lightHouse.amLighthouse && serveDns { - l.Debugln("Starting dns server") - dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) - } - return &Control{ state: StateReady, f: ifce, @@ -295,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg cancel: cancel, sshStart: sshStart, statsStart: statsStart, - dnsStart: dnsStart, + dnsStart: ds.Start, lighthouseStart: lightHouse.StartUpdateWorker, connectionManagerStart: connManager.Start, }, nil From 8c50fc3f60f97203090b028d6c26bbef1348b623 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 21 Apr 2026 13:19:54 -0500 Subject: [PATCH 29/44] Plug the conntrack cache ticker leak and nebula-service log.Fatal calls (#1669) --- cmd/nebula-service/main.go | 13 ++++++++----- cmd/nebula-service/service.go | 18 ++++++++---------- firewall/cache.go | 17 ++++++++++++----- interface.go | 6 ++++-- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index aaec80f7..021e36fa 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -50,9 +50,15 @@ func main() { os.Exit(0) } + l := logrus.New() + l.Out = os.Stdout + if *serviceFlag != "" { - doService(configPath, configTest, 
Build, serviceFlag) - os.Exit(1) + if err := doService(configPath, configTest, Build, serviceFlag); err != nil { + l.WithError(err).Error("Service command failed") + os.Exit(1) + } + return } if *configPath == "" { @@ -61,9 +67,6 @@ func main() { os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout - c := config.NewC(l) err := c.Load(*configPath) if err != nil { diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index a54fb0f3..1f45f95b 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -57,11 +57,11 @@ func fileExists(filename string) bool { return true } -func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { +func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error { if *configPath == "" { ex, err := os.Executable() if err != nil { - panic(err) + return err } *configPath = filepath.Dir(ex) + "/config.yaml" if !fileExists(*configPath) { @@ -88,13 +88,13 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag * // - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use s, err := service.New(prg, svcConfig) if err != nil { - log.Fatal(err) + return err } errs := make(chan error, 5) logger, err = s.Logger(errs) if err != nil { - log.Fatal(err) + return err } go func() { @@ -109,18 +109,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag * switch *serviceFlag { case "run": - err = s.Run() - if err != nil { + if err := s.Run(); err != nil { // Route any errors to the system logger logger.Error(err) } default: - err := service.Control(s, *serviceFlag) - if err != nil { + if err := service.Control(s, *serviceFlag); err != nil { log.Printf("Valid actions: %q\n", service.ControlAction) - log.Fatal(err) + return err } - return } + return nil } diff --git a/firewall/cache.go b/firewall/cache.go index 71b83f43..a4ffc100 100644 --- a/firewall/cache.go +++ 
b/firewall/cache.go @@ -1,6 +1,7 @@ package firewall import ( + "context" "sync/atomic" "time" @@ -18,7 +19,7 @@ type ConntrackCacheTicker struct { cache ConntrackCache } -func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { +func NewConntrackCacheTicker(ctx context.Context, d time.Duration) *ConntrackCacheTicker { if d == 0 { return nil } @@ -27,15 +28,21 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { cache: ConntrackCache{}, } - go c.tick(d) + go c.tick(ctx, d) return c } -func (c *ConntrackCacheTicker) tick(d time.Duration) { +func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) { + t := time.NewTicker(d) + defer t.Stop() for { - time.Sleep(d) - c.cacheTick.Add(1) + select { + case <-ctx.Done(): + return + case <-t.C: + c.cacheTick.Add(1) + } } } diff --git a/interface.go b/interface.go index 481b1d4d..6d040884 100644 --- a/interface.go +++ b/interface.go @@ -85,6 +85,7 @@ type Interface struct { conntrackCacheTimeout time.Duration + ctx context.Context writers []udp.Conn readers []io.ReadWriteCloser wg sync.WaitGroup @@ -170,6 +171,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { cs := c.pki.getCertState() ifce := &Interface{ + ctx: ctx, pki: c.pki, hostMap: c.HostMap, outside: c.Outside, @@ -303,7 +305,7 @@ func (f *Interface) listenOut(i int) { li = f.outside } - ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() plaintext := make([]byte, udp.MTU) h := &header.H{} @@ -328,7 +330,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout) for { n, err := reader.Read(packet) From 
32a7c044985996d83d39313cb2b015cd987d1459 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Tue, 21 Apr 2026 16:32:48 -0400 Subject: [PATCH 30/44] Return NODATA instead of NXDOMAIN for missing record types (#1668) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The DNS responder was setting RCODE=NXDOMAIN (Name Error) any time the answer section was empty, including for names that exist in the lighthouse but lack a record of the requested type (e.g. an AAAA query for a v4-only host). Per RFC 2308 §2.1, NXDOMAIN means "the domain referred to by the QNAME does not exist", and per RFC 2308 §2.2 a name that exists with no record of the requested type must be answered with RCODE=NOERROR and an empty answer section (NODATA). The practical fallout: busybox ping in Alpine issues AAAA first, treats NXDOMAIN as a hard failure, and never falls through to A. Returning NODATA lets the resolver continue to the A query as it should. Track whether any queried A/AAAA name is known in either map and only set RcodeNameError when no queried name exists at all. --- dns_server.go | 30 ++++++++++++++++++++++-------- dns_server_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/dns_server.go b/dns_server.go index 75c56f0f..8af88b52 100644 --- a/dns_server.go +++ b/dns_server.go @@ -216,22 +216,28 @@ func (d *dnsServer) Stop() { d.shutdownServer(srv, started, "stop") } -func (d *dnsServer) Query(q uint16, data string) netip.Addr { +// Query returns the address for the given name and query type. The second +// return value reports whether the name is known at all (in either A or AAAA), +// which lets callers distinguish NODATA from NXDOMAIN. 
+func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) { data = strings.ToLower(data) d.RLock() defer d.RUnlock() + addr4, haveV4 := d.dnsMap4[data] + addr6, haveV6 := d.dnsMap6[data] + nameExists := haveV4 || haveV6 switch q { case dns.TypeA: - if r, ok := d.dnsMap4[data]; ok { - return r + if haveV4 { + return addr4, nameExists } case dns.TypeAAAA: - if r, ok := d.dnsMap6[data]; ok { - return r + if haveV6 { + return addr6, nameExists } } - return netip.Addr{} + return netip.Addr{}, nameExists } func (d *dnsServer) QueryCert(data string) string { @@ -305,12 +311,20 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { } func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { + // Per RFC 2308 §2.2, a name that exists but has no record of the requested + // type must be answered with NOERROR and an empty answer section (NODATA), + // not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not + // exist at all. + anyNameExists := false for _, q := range m.Question { switch q.Qtype { case dns.TypeA, dns.TypeAAAA: qType := dns.TypeToString[q.Qtype] d.l.Debugf("Query for %s %s", qType, q.Name) - ip := d.Query(q.Qtype, q.Name) + ip, nameExists := d.Query(q.Qtype, q.Name) + if nameExists { + anyNameExists = true + } if ip.IsValid() { rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { @@ -333,7 +347,7 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { } } - if len(m.Answer) == 0 { + if len(m.Answer) == 0 && !anyNameExists { m.Rcode = dns.RcodeNameError } } diff --git a/dns_server_test.go b/dns_server_test.go index c33c0480..ef8a5a64 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -33,18 +33,43 @@ func TestParsequery(t *testing.T) { netip.MustParseAddr("fd01::25"), } ds.Add("test.com.com", addrs) + ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")}) + ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")}) m := &dns.Msg{} 
m.SetQuestion("test.com.com", dns.TypeA) ds.parseQuery(m, nil) assert.NotNil(t, m.Answer) assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) m = &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeAAAA) ds.parseQuery(m, nil) assert.NotNil(t, m.Answer) assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + // A known name with no record of the requested type should return NODATA + // (NOERROR with empty answer), not NXDOMAIN. + m = &dns.Msg{} + m.SetQuestion("v4only.com.com", dns.TypeAAAA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + m = &dns.Msg{} + m.SetQuestion("v6only.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + // An unknown name should still return NXDOMAIN. + m = &dns.Msg{} + m.SetQuestion("unknown.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) } func Test_getDnsServerAddr(t *testing.T) { From e753e6e93c215780aef992d0481d894f19f5b9fe Mon Sep 17 00:00:00 2001 From: John Maguire Date: Tue, 21 Apr 2026 16:33:32 -0400 Subject: [PATCH 31/44] Immediate Lighthouse update after reconfig/reconnect (#1645) --- e2e/handshakes_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++ handshake_ix.go | 10 ++++++ lighthouse.go | 18 +++++++++- 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 67b166b1..7729465b 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -1369,6 +1369,81 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { theirControl.Stop() } +func TestLighthouseUpdateOnReload(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + // Create the 
lighthouse + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh", "10.128.0.1/24", m{"lighthouse": m{"am_lighthouse": true}}) + + // Create a client with NO lighthouse configured and a long update interval. + // The initial SendUpdate at startup will be a no-op since no lighthouses are known. + myControl, myVpnIpNet, _, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.2/24", m{ + "lighthouse": m{ + "interval": 600, + "local_allow_list": m{ + "10.0.0.0/24": true, + "::/0": false, + }, + }, + }) + + r := router.NewR(t, lhControl, myControl) + defer r.RenderFlow() + + lhControl.Start() + myControl.Start() + + // Drain any startup packets (there should be none meaningful) + r.FlushAll() + + // Verify lighthouse has no knowledge of the client + assert.Nil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr())) + + // Build a new config that adds the lighthouse + newSettings := make(m) + for k, v := range myConfig.Settings { + newSettings[k] = v + } + newSettings["static_host_map"] = m{ + lhVpnIpNet[0].Addr().String(): []any{lhUdpAddr.String()}, + } + newSettings["lighthouse"] = m{ + "hosts": []any{lhVpnIpNet[0].Addr().String()}, + "interval": 600, + "local_allow_list": m{ + "10.0.0.0/24": true, + "::/0": false, + }, + } + newCfg, err := yaml.Marshal(newSettings) + require.NoError(t, err) + + // Reload the config. The lighthouse.hosts change triggers TriggerUpdate, + // which wakes the update worker. It calls SendUpdate, initiating a + // handshake to the new lighthouse and caching the HostUpdateNotification. + require.NoError(t, myConfig.ReloadConfigString(string(newCfg))) + + // Route until the lighthouse receives the HostUpdateNotification. + // This covers: handshake stage 1, stage 2, then the cached update. 
+ done := make(chan struct{}) + go func() { + r.RouteForAllUntilAfterMsgTypeTo(lhControl, header.LightHouse, 0) + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for lighthouse update after config reload") + } + + // Verify lighthouse now has the client's addresses + assert.NotNil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr())) + + r.RenderHostmaps("Final hostmaps", lhControl, myControl) + lhControl.Stop() + myControl.Stop() +} + func TestGoodHandshakeUnsafeDest(t *testing.T) { unsafePrefix := "192.168.6.0/24" ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) diff --git a/handshake_ix.go b/handshake_ix.go index 4e04f450..f081eb8c 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -462,6 +462,11 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) hostinfo.remotes.RefreshFromHandshake(vpnAddrs) + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } + return } @@ -674,5 +679,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe hostinfo.remotes.RefreshFromHandshake(vpnAddrs) f.metricHandshakes.Update(duration) + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } + return false } diff --git a/lighthouse.go b/lighthouse.go index 36eb9aa0..50140e9e 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -69,7 +69,8 @@ type LightHouse struct { // Addr's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]netip.Addr] - queryChan chan netip.Addr + updateTrigger chan struct{} + queryChan chan netip.Addr calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote @@ -105,6 +106,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c 
*config.C, nebulaPort: nebulaPort, punchConn: pc, punchy: p, + updateTrigger: make(chan struct{}, 1), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } @@ -316,6 +318,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { if !initial { //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic lh.l.Info("lighthouse.hosts has changed") + lh.TriggerUpdate() } } @@ -841,11 +844,24 @@ func (lh *LightHouse) StartUpdateWorker() { return case <-clockSource.C: continue + case <-lh.updateTrigger: + continue } } }() } +// TriggerUpdate requests an immediate lighthouse update. This is a non-blocking +// operation intended to be called after a handshake completes with a lighthouse, +// so the lighthouse has our current addresses without waiting for the next +// periodic update. +func (lh *LightHouse) TriggerUpdate() { + select { + case lh.updateTrigger <- struct{}{}: + default: + } +} + func (lh *LightHouse) SendUpdate() { var v4 []*V4AddrPort var v6 []*V6AddrPort From 2a1cc62001caccbbb875e03da90b671c9ac5ec16 Mon Sep 17 00:00:00 2001 From: Guy Nesher Date: Wed, 22 Apr 2026 20:42:14 +0300 Subject: [PATCH 32/44] fix: guard QueryCert against panic on short/empty QNAME (#1635) * fix: guard QueryCert against panic on short/empty QNAME QueryCert slices data[:len(data)-1] to strip a trailing dot, which panics when data is empty (slice bounds [:-1]). Add a length check to return early for inputs shorter than a minimal valid "x." form. While miekg/dns currently rejects wire-format packets that would produce an empty QNAME, the Nebula code should not rely on library behavior for crash safety. 
Made-with: Cursor * fix merge conflicts --------- Co-authored-by: JackDoan --- dns_server.go | 3 +++ dns_server_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/dns_server.go b/dns_server.go index 8af88b52..5b12b922 100644 --- a/dns_server.go +++ b/dns_server.go @@ -241,6 +241,9 @@ func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) { } func (d *dnsServer) QueryCert(data string) string { + if len(data) < 2 { + return "" + } ip, err := netip.ParseAddr(data[:len(data)-1]) if err != nil { return "" diff --git a/dns_server_test.go b/dns_server_test.go index ef8a5a64..e09d3fa9 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -16,6 +16,19 @@ import ( "github.com/stretchr/testify/require" ) +type stubDNSWriter struct{} + +func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (stubDNSWriter) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353} +} +func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil } +func (stubDNSWriter) WriteMsg(*dns.Msg) error { return nil } +func (stubDNSWriter) Close() error { return nil } +func (stubDNSWriter) TsigStatus() error { return nil } +func (stubDNSWriter) TsigTimersOnly(bool) {} +func (stubDNSWriter) Hijack() {} + func TestParsequery(t *testing.T) { l := logrus.New() hostMap := &HostMap{} @@ -70,6 +83,19 @@ func TestParsequery(t *testing.T) { ds.parseQuery(m, nil) assert.Empty(t, m.Answer) assert.Equal(t, dns.RcodeNameError, m.Rcode) + + // short lookups should not fail + m = &dns.Msg{} + m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}} + ds.parseQuery(m, stubDNSWriter{}) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) + + m = &dns.Msg{} + m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}} + ds.parseQuery(m, stubDNSWriter{}) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) } func 
Test_getDnsServerAddr(t *testing.T) { From 5f00ab4b745ba04db458c119f876bcc9cb104e24 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 22 Apr 2026 17:18:06 -0500 Subject: [PATCH 33/44] Fix e2e tests writing after the tester tun is closed causing a panic (#1681) --- udp/udp_tester.go | 57 +++++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5db72555..388b17d0 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -7,7 +7,7 @@ import ( "io" "net/netip" "os" - "sync/atomic" + "sync" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -37,8 +37,16 @@ type TesterConn struct { RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula - closed atomic.Bool - l *logrus.Logger + // done is closed exactly once by Close. Senders select on it so they + // never race with a channel close; readers exit when it fires. The + // packet channels are intentionally never closed - that was the source + // of `send on closed channel` panics when a WriteTo/Send from another + // goroutine passed the close check and reached the send just after + // Close ran. 
+ done chan struct{} + closeOnce sync.Once + + l *logrus.Logger } func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { @@ -46,6 +54,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), + done: make(chan struct{}), l: l, }, nil } @@ -54,10 +63,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *TesterConn) Send(packet *Packet) { - if u.closed.Load() { - return - } - h := &header.H{} if err := h.Parse(packet.Data); err != nil { panic(err) @@ -68,7 +73,10 @@ func (u *TesterConn) Send(packet *Packet) { WithField("dataLen", len(packet.Data)). Debug("UDP receiving injected packet") } - u.RxPackets <- packet + select { + case <-u.done: + case u.RxPackets <- packet: + } } // Get will pull a UdpPacket from the transmit queue @@ -76,7 +84,12 @@ func (u *TesterConn) Send(packet *Packet) { // packets were ingested from the tun side (in most cases), you can send them with Tun.Send func (u *TesterConn) Get(block bool) *Packet { if block { - return <-u.TxPackets + select { + case <-u.done: + return nil + case p := <-u.TxPackets: + return p + } } select { @@ -92,10 +105,6 @@ func (u *TesterConn) Get(block bool) *Packet { //********************************************************************************************************************// func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { - if u.closed.Load() { - return io.ErrClosedPipe - } - p := &Packet{ Data: make([]byte, len(b), len(b)), From: u.Addr, @@ -103,17 +112,22 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { } copy(p.Data, b) - u.TxPackets <- p - return nil + select { + case <-u.done: + return 
io.ErrClosedPipe + case u.TxPackets <- p: + return nil + } } func (u *TesterConn) ListenOut(r EncReader) error { for { - p, ok := <-u.RxPackets - if !ok { + select { + case <-u.done: return os.ErrClosed + case p := <-u.RxPackets: + r(p.From, p.Data) } - r(p.From, p.Data) } } @@ -137,9 +151,8 @@ func (u *TesterConn) Rebind() error { } func (u *TesterConn) Close() error { - if u.closed.CompareAndSwap(false, true) { - close(u.RxPackets) - close(u.TxPackets) - } + u.closeOnce.Do(func() { + close(u.done) + }) return nil } From db9218b0beac85b7539611484a7f406c9feb10a3 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 23 Apr 2026 13:51:15 -0500 Subject: [PATCH 34/44] Another shot at the flakey smoke test (#1688) --- .github/workflows/smoke/smoke.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/smoke/smoke.sh b/.github/workflows/smoke/smoke.sh index 62ceafe1..cad9dde7 100755 --- a/.github/workflows/smoke/smoke.sh +++ b/.github/workflows/smoke/smoke.sh @@ -82,7 +82,7 @@ docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host docker exec host2 ncat -nklv 0.0.0.0 2000 & docker exec host3 ncat -nklv 0.0.0.0 2000 & -docker exec host4 ncat -nkluv 0.0.0.0 4000 & +docker exec host4 ncat -e '/usr/bin/echo helloagainfromhost4' -nkluv 0.0.0.0 4000 & docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & @@ -155,11 +155,11 @@ echo " *** Testing conntrack" echo set -x -# host2 speaking to host4 on UDP 4000 should allow it to reply, when firewall rules would normally not permit this -docker exec host2 sh -c "/usr/bin/echo host2 | ncat -nuv 192.168.100.4 4000" -docker exec host2 ncat -e '/usr/bin/echo helloagainfromhost2' -nkluv 0.0.0.0 4000 & -sleep 1 -docker exec host4 sh -c "/usr/bin/echo host4 | ncat -nuv 192.168.100.2 4000" +# host4's outbound firewall only allows ICMP to the lighthouse, so host4 +# cannot initiate UDP to 
host2. Once host2 initiates a flow to host4:4000, +# conntrack must let host4's listener reply on that flow. If it doesn't, +# the echo back from host4 never reaches host2. +docker exec host2 sh -c "(/usr/bin/echo host2; sleep 2) | ncat -nuv 192.168.100.4 4000" | grep -q helloagainfromhost4 docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' From db85d61c2331f1f1508b1f9be11178d0ada207a5 Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Thu, 23 Apr 2026 14:53:52 -0400 Subject: [PATCH 35/44] SSH handshake in goroutine and defer close (#1640) * SSH handshake in goroutine and defer close --- sshd/server.go | 83 ++++++++++++++++++++++++------------------------- sshd/session.go | 16 +++++----- 2 files changed, 47 insertions(+), 52 deletions(-) diff --git a/sshd/server.go b/sshd/server.go index a8b60ba7..4b5cc3e0 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -2,10 +2,10 @@ package sshd import ( "bytes" + "context" "errors" "fmt" "net" - "sync" "github.com/armon/go-radix" "github.com/sirupsen/logrus" @@ -27,20 +27,21 @@ type SSHServer struct { commands *radix.Tree listener net.Listener - // Locks the conns/counter to avoid concurrent map access - connsLock sync.Mutex - conns map[int]*session - counter int + // Call the cancel() function to stop all active sessions + ctx context.Context + cancel func() } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { + ctx, cancel := context.WithCancel(context.Background()) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, commands: radix.New(), - conns: make(map[int]*session), + ctx: ctx, + cancel: cancel, } cc := ssh.CertChecker{ @@ -175,44 +176,44 @@ func (s *SSHServer) run() { } return } - - conn, chans, reqs, err := ssh.NewServerConn(c, s.config) - fp := "" - if conn != nil { - fp = conn.Permissions.Extensions["fp"] - } - - if err != nil { - l := 
s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + go func(c net.Conn) { + // NewServerConn may block while waiting for the client to complete the handshake. + // Ensure that a bad client doesn't hurt us by checking for the parent context + // cancellation before calling NewServerConn, and forcing the socket to close when + // the context is cancelled. + sessionContext, sessionCancel := context.WithCancel(s.ctx) + go func() { + <-sessionContext.Done() + c.Close() + }() + conn, chans, reqs, err := ssh.NewServerConn(c, s.config) + fp := "" if conn != nil { - l = l.WithField("sshUser", conn.User()) - conn.Close() + fp = conn.Permissions.Extensions["fp"] } - if fp != "" { - l = l.WithField("sshFingerprint", fp) + + if err != nil { + l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + if conn != nil { + l = l.WithField("sshUser", conn.User()) + conn.Close() + } + if fp != "" { + l = l.WithField("sshFingerprint", fp) + } + l.Warn("failed to handshake") + sessionCancel() + return } - l.Warn("failed to handshake") - continue - } - l := s.l.WithField("sshUser", conn.User()) - l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") + l := s.l.WithField("sshUser", conn.User()) + l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") - session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) - s.connsLock.Lock() - s.counter++ - counter := s.counter - s.conns[counter] = session - s.connsLock.Unlock() + NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session")) - go ssh.DiscardRequests(reqs) - go func() { - <-session.exitChan - s.l.WithField("id", counter).Debug("closing conn") - s.connsLock.Lock() - delete(s.conns, counter) - s.connsLock.Unlock() - }() + go ssh.DiscardRequests(reqs) + + }(c) } } @@ -226,9 +227,5 @@ func (s *SSHServer) Stop() { } func (s *SSHServer) closeSessions() { - 
s.connsLock.Lock() - for _, c := range s.conns { - c.Close() - } - s.connsLock.Unlock() + s.cancel() } diff --git a/sshd/session.go b/sshd/session.go index 87cc216f..39c81bd0 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -17,15 +17,15 @@ type session struct { c *ssh.ServerConn term *term.Terminal commands *radix.Tree - exitChan chan bool + cancel func() } -func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session { +func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *logrus.Entry) *session { s := &session{ commands: radix.NewFromMap(commands.ToMap()), l: l, c: conn, - exitChan: make(chan bool), + cancel: cancel, } s.commands.Insert("logout", &Command{ @@ -42,6 +42,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New } func (s *session) handleChannels(chans <-chan ssh.NewChannel) { + defer s.Close() for newChannel := range chans { if newChannel.ChannelType() != "session" { s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") @@ -100,7 +101,6 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { if err != nil { s.l.WithError(err).Info("Error handling ssh session requests") - s.Close() return } } @@ -123,12 +123,11 @@ func (s *session) createTerm(channel ssh.Channel) *term.Terminal { return "", 0, false } - go s.handleInput(channel) + go s.handleInput() return term } -func (s *session) handleInput(channel ssh.Channel) { - defer s.Close() +func (s *session) handleInput() { w := &stringWriter{w: s.term} for { line, err := s.term.ReadLine() @@ -170,10 +169,9 @@ func (s *session) dispatchCommand(line string, w StringWriter) { } _ = execCommand(c, args[1:], w) - return } func (s *session) Close() { s.c.Close() - s.exitChan <- true + s.cancel() } From 5f890dbc3410338c38813c1b8a823c2cee74bb47 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Fri, 24 Apr 2026 
13:12:42 -0500 Subject: [PATCH 36/44] noise: only type-assert once (#1691) --- noise.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/noise.go b/noise.go index 57990a79..0491da17 100644 --- a/noise.go +++ b/noise.go @@ -15,14 +15,12 @@ type endianness interface { var noiseEndianness endianness = binary.BigEndian type NebulaCipherState struct { - c noise.Cipher - //k [32]byte - //n uint64 + c cipher.AEAD } func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { - return &NebulaCipherState{c: s.Cipher()} - + x := s.Cipher() + return &NebulaCipherState{c: x.(cipher.AEAD)} } // EncryptDanger encrypts and authenticates a given payload. @@ -46,7 +44,7 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n nb[2] = 0 nb[3] = 0 noiseEndianness.PutUint64(nb[4:], n) - out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad) + out = s.c.Seal(out, nb, plaintext, ad) //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) return out, nil } else { @@ -61,7 +59,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb[2] = 0 nb[3] = 0 noiseEndianness.PutUint64(nb[4:], n) - return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad) + return s.c.Open(out, nb, ciphertext, ad) } else { return []byte{}, nil } @@ -69,7 +67,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, func (s *NebulaCipherState) Overhead() int { if s != nil { - return s.c.(cipher.AEAD).Overhead() + return s.c.Overhead() } return 0 } From d0f02ba87343066fa83cdcb17a440ad8a1025fcc Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 27 Apr 2026 09:41:47 -0500 Subject: [PATCH 37/44] Switch to slog, remove logrus (#1672) --- .golangci.yaml | 14 + bits.go | 38 +- cmd/nebula-service/logs_generic.go | 13 +- cmd/nebula-service/logs_windows.go | 110 +++-- cmd/nebula-service/main.go | 19 +- cmd/nebula-service/service.go | 16 +- cmd/nebula/main.go | 
17 +- cmd/nebula/notify_linux.go | 15 +- cmd/nebula/notify_notlinux.go | 4 +- config/config.go | 21 +- connection_manager.go | 122 ++--- connection_manager_test.go | 33 +- connection_state.go | 3 +- control.go | 14 +- control_test.go | 3 +- dns_server.go | 23 +- dns_server_test.go | 12 +- e2e/handshakes_test.go | 35 +- e2e/helpers_test.go | 44 +- examples/config.yml | 22 +- examples/go_service/main.go | 5 +- firewall.go | 65 +-- firewall/cache.go | 13 +- firewall/cache_test.go | 69 +++ firewall_test.go | 100 ++-- go.mod | 1 - go.sum | 2 - handshake_ix.go | 541 +++++++++++++-------- handshake_manager.go | 157 +++--- hostmap.go | 76 +-- hostmap_test.go | 2 +- inside.go | 119 +++-- interface.go | 66 +-- lighthouse.go | 228 ++++++--- logger.go | 45 -- logging/logger.go | 233 +++++++++ logging/logger_bench_test.go | 90 ++++ main.go | 39 +- outside.go | 157 +++--- test/tun.go => overlay/overlaytest/noop.go | 7 +- overlay/route.go | 9 +- overlay/route_test.go | 4 +- overlay/tun.go | 8 +- overlay/tun_android.go | 8 +- overlay/tun_darwin.go | 17 +- overlay/tun_disabled.go | 21 +- overlay/tun_freebsd.go | 20 +- overlay/tun_ios.go | 8 +- overlay/tun_linux.go | 72 +-- overlay/tun_netbsd.go | 16 +- overlay/tun_openbsd.go | 16 +- overlay/tun_tester.go | 13 +- overlay/tun_windows.go | 16 +- overlay/user.go | 4 +- pki.go | 18 +- pki_hup_benchmark_test.go | 2 +- punchy.go | 14 +- punchy_test.go | 173 ++++++- relay_manager.go | 165 ++++--- remote_list.go | 18 +- service/service_test.go | 5 +- ssh.go | 83 ++-- sshd/server.go | 37 +- sshd/session.go | 14 +- stats.go | 24 +- test/logger.go | 68 ++- udp/udp_android.go | 5 +- udp/udp_bsd.go | 5 +- udp/udp_darwin.go | 10 +- udp/udp_generic.go | 8 +- udp/udp_linux.go | 24 +- udp/udp_netbsd.go | 5 +- udp/udp_rio_windows.go | 14 +- udp/udp_tester.go | 18 +- udp/udp_windows.go | 7 +- util/error.go | 27 +- util/error_test.go | 68 +-- 77 files changed, 2299 insertions(+), 1338 deletions(-) create mode 100644 firewall/cache_test.go delete mode 
100644 logger.go create mode 100644 logging/logger.go create mode 100644 logging/logger_bench_test.go rename test/tun.go => overlay/overlaytest/noop.go (68%) diff --git a/.golangci.yaml b/.golangci.yaml index bd82a952..be0513d4 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -2,7 +2,21 @@ version: "2" linters: default: none enable: + - sloglint - testifylint + settings: + sloglint: + # Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and + # the package-level slog equivalents. Use l.Log(ctx, level, ...) for + # custom levels instead of LogAttrs when you can. + # + # LogAttrs is also flagged by this rule because it takes ...slog.Attr; + # the few legitimate sites (where attrs is built up as a []slog.Attr) + # carry a //nolint:sloglint with rationale. + kv-only: true + # no-mixed-args is on by default: forbids mixing kv and attrs in one call. + # discard-handler is on by default (since Go 1.24): suggests + # slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil). exclusions: generated: lax presets: diff --git a/bits.go b/bits.go index af11cc48..5c8f902b 100644 --- a/bits.go +++ b/bits.go @@ -1,8 +1,10 @@ package nebula import ( + "context" + "log/slog" + "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" ) type Bits struct { @@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits { return b } -func (b *Bits) Check(l *logrus.Logger, i uint64) bool { +func (b *Bits) Check(l *slog.Logger, i uint64) bool { // If i is the next number, return true. 
if i > b.current { return true @@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool { } // Not within the window - if l.Level >= logrus.DebugLevel { - l.Debugf("rejected a packet (top) %d %d\n", b.current, i) + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("rejected a packet (top)", + "current", b.current, + "incoming", i, + ) } return false } -func (b *Bits) Update(l *logrus.Logger, i uint64) bool { +func (b *Bits) Update(l *slog.Logger, i uint64) bool { // If i is the next number, return true and update current. if i == b.current+1 { // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter @@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // Check to see if it's a duplicate if i > b.current-b.length || i < b.length && b.current < b.length { if b.current == i || b.bits[i%b.length] == true { - if l.Level >= logrus.DebugLevel { - l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). - Debug("Receive window") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Receive window", + "accepted", false, + "currentCounter", b.current, + "incomingCounter", i, + "reason", "duplicate", + ) } b.dupeCounter.Inc(1) return false @@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // In all other cases, fail and don't change current. b.outOfWindowCounter.Inc(1) - if l.Level >= logrus.DebugLevel { - l.WithField("accepted", false). - WithField("currentCounter", b.current). - WithField("incomingCounter", i). - WithField("reason", "nonsense"). 
- Debug("Receive window") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Receive window", + "accepted", false, + "currentCounter", b.current, + "incomingCounter", i, + "reason", "nonsense", + ) } return false } diff --git a/cmd/nebula-service/logs_generic.go b/cmd/nebula-service/logs_generic.go index 3b7cdd1c..cc06b4c5 100644 --- a/cmd/nebula-service/logs_generic.go +++ b/cmd/nebula-service/logs_generic.go @@ -3,8 +3,15 @@ package main -import "github.com/sirupsen/logrus" +import ( + "log/slog" + "os" -func HookLogger(l *logrus.Logger) { - // Do nothing, let the logs flow to stdout/stderr + "github.com/slackhq/nebula/logging" +) + +// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows +// platforms have no special sink to integrate with. +func newPlatformLogger() *slog.Logger { + return logging.NewLogger(os.Stdout) } diff --git a/cmd/nebula-service/logs_windows.go b/cmd/nebula-service/logs_windows.go index af6480ef..ca0a55c5 100644 --- a/cmd/nebula-service/logs_windows.go +++ b/cmd/nebula-service/logs_windows.go @@ -1,54 +1,86 @@ package main import ( - "fmt" - "io/ioutil" - "os" + "context" + "log/slog" + "strings" + "sync" - "github.com/kardianos/service" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/logging" ) -// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer -// logrus output will be discarded -func HookLogger(l *logrus.Logger) { - l.AddHook(newLogHook(logger)) - l.SetOutput(ioutil.Discard) +// newPlatformLogger returns a *slog.Logger that routes every log record +// through the Windows service logger so records end up in the Windows +// Event Log. 
All the heavy lifting (level management, format swap, +// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler; +// this file only contributes: +// +// - an io.Writer that forwards each formatted line to the service +// logger at the current record's Event Log severity, and +// - a thin severityTag that embeds *logging.Handler and overrides +// only Handle / WithAttrs / WithGroup, so Event Viewer's severity +// column and severity-based filters keep working the way they did +// before the slog migration. +// +// Format (text vs json) is carried by the embedded *logging.Handler, so +// logging.format: json in config still produces JSON lines in Event +// Viewer, same as the pre-slog logrus setup. +func newPlatformLogger() *slog.Logger { + w := &eventLogWriter{} + return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w}) } -type logHook struct { - sl service.Logger +// eventLogWriter forwards slog-formatted lines to the Windows service +// logger at the severity most recently stashed by severityTag.Handle. +// The mutex serializes the stash + inner.Handle + Write cycle per record +// across all concurrent goroutines; slog's builtin text/json handlers +// each hold their own mutex around Write, but that only protects the +// Write call itself, not our stash-then-handle sequence. 
+type eventLogWriter struct { + mu sync.Mutex + level slog.Level } -func newLogHook(sl service.Logger) *logHook { - return &logHook{sl: sl} -} - -func (h *logHook) Fire(entry *logrus.Entry) error { - line, err := entry.String() - if err != nil { - fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err) - return err - } - - switch entry.Level { - case logrus.PanicLevel: - return h.sl.Error(line) - case logrus.FatalLevel: - return h.sl.Error(line) - case logrus.ErrorLevel: - return h.sl.Error(line) - case logrus.WarnLevel: - return h.sl.Warning(line) - case logrus.InfoLevel: - return h.sl.Info(line) - case logrus.DebugLevel: - return h.sl.Info(line) +func (w *eventLogWriter) Write(p []byte) (int, error) { + line := strings.TrimRight(string(p), "\n") + switch { + case w.level >= slog.LevelError: + return len(p), logger.Error(line) + case w.level >= slog.LevelWarn: + return len(p), logger.Warning(line) default: - return nil + return len(p), logger.Info(line) } } -func (h *logHook) Levels() []logrus.Level { - return logrus.AllLevels +// severityTag embeds *logging.Handler to pick up everything it does for +// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat, +// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup +// so each record's slog.Level is stashed on the writer before formatting +// and so derived handlers stay wrapped as severityTag rather than +// downgrading to bare *logging.Handler. 
+type severityTag struct { + *logging.Handler + w *eventLogWriter +} + +func (s *severityTag) Handle(ctx context.Context, r slog.Record) error { + s.w.mu.Lock() + defer s.w.mu.Unlock() + s.w.level = r.Level + return s.Handler.Handle(ctx, r) +} + +func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return s + } + return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w} +} + +func (s *severityTag) WithGroup(name string) slog.Handler { + if name == "" { + return s + } + return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w} } diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 021e36fa..19fb3a9f 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -7,9 +7,9 @@ import ( "runtime/debug" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/util" ) @@ -50,12 +50,11 @@ func main() { os.Exit(0) } - l := logrus.New() - l.Out = os.Stdout + l := logging.NewLogger(os.Stdout) if *serviceFlag != "" { if err := doService(configPath, configTest, Build, serviceFlag); err != nil { - l.WithError(err).Error("Service command failed") + l.Error("Service command failed", "error", err) os.Exit(1) } return @@ -74,6 +73,16 @@ func main() { os.Exit(1) } + if err := logging.ApplyConfig(l, c); err != nil { + fmt.Printf("failed to apply logging config: %s", err) + os.Exit(1) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) @@ -90,7 +99,7 @@ func main() { go ctrl.ShutdownBlock() if err := wait(); err != nil { - l.WithError(err).Error("Nebula stopped due to fatal error") + l.Error("Nebula 
stopped due to fatal error", "error", err) os.Exit(2) } diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index 1f45f95b..6551ceb4 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -7,9 +7,9 @@ import ( "path/filepath" "github.com/kardianos/service" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" ) var logger service.Logger @@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error { // Start should not block. logger.Info("Nebula service starting.") - l := logrus.New() - HookLogger(l) + l := newPlatformLogger() c := config.NewC(l) err := c.Load(*p.configPath) @@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error { return fmt.Errorf("failed to load config: %s", err) } + if err := logging.ApplyConfig(l, c); err != nil { + return fmt.Errorf("failed to apply logging config: %s", err) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + p.control, err = nebula.Main(c, *p.configTest, Build, l, nil) if err != nil { return err @@ -85,7 +93,7 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag * // Here are what the different loggers are doing: // - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr // - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log) - // - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use + // - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger s, err := service.New(prg, svcConfig) if err != nil { return err diff --git a/cmd/nebula/main.go 
b/cmd/nebula/main.go index f29f4537..d7f0de93 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -7,9 +7,9 @@ import ( "runtime/debug" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/util" ) @@ -55,8 +55,7 @@ func main() { os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout + l := logging.NewLogger(os.Stdout) c := config.NewC(l) err := c.Load(*configPath) @@ -65,6 +64,16 @@ func main() { os.Exit(1) } + if err := logging.ApplyConfig(l, c); err != nil { + fmt.Printf("failed to apply logging config: %s", err) + os.Exit(1) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) @@ -82,7 +91,7 @@ func main() { notifyReady(l) if err := wait(); err != nil { - l.WithError(err).Error("Nebula stopped due to fatal error") + l.Error("Nebula stopped due to fatal error", "error", err) os.Exit(2) } diff --git a/cmd/nebula/notify_linux.go b/cmd/nebula/notify_linux.go index 8c3dca55..965986a9 100644 --- a/cmd/nebula/notify_linux.go +++ b/cmd/nebula/notify_linux.go @@ -1,11 +1,10 @@ package main import ( + "log/slog" "net" "os" "time" - - "github.com/sirupsen/logrus" ) // SdNotifyReady tells systemd the service is ready and dependent services can now be started @@ -13,30 +12,30 @@ import ( // https://www.freedesktop.org/software/systemd/man/systemd.service.html const SdNotifyReady = "READY=1" -func notifyReady(l *logrus.Logger) { +func notifyReady(l *slog.Logger) { sockName := os.Getenv("NOTIFY_SOCKET") if sockName == "" { - l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal") + l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal") return } conn, err := 
net.DialTimeout("unixgram", sockName, time.Second) if err != nil { - l.WithError(err).Error("failed to connect to systemd notification socket") + l.Error("failed to connect to systemd notification socket", "error", err) return } defer conn.Close() err = conn.SetWriteDeadline(time.Now().Add(time.Second)) if err != nil { - l.WithError(err).Error("failed to set the write deadline for the systemd notification socket") + l.Error("failed to set the write deadline for the systemd notification socket", "error", err) return } if _, err = conn.Write([]byte(SdNotifyReady)); err != nil { - l.WithError(err).Error("failed to signal the systemd notification socket") + l.Error("failed to signal the systemd notification socket", "error", err) return } - l.Debugln("notified systemd the service is ready") + l.Debug("notified systemd the service is ready") } diff --git a/cmd/nebula/notify_notlinux.go b/cmd/nebula/notify_notlinux.go index e7758e09..48cfe949 100644 --- a/cmd/nebula/notify_notlinux.go +++ b/cmd/nebula/notify_notlinux.go @@ -3,8 +3,8 @@ package main -import "github.com/sirupsen/logrus" +import "log/slog" -func notifyReady(_ *logrus.Logger) { +func notifyReady(_ *slog.Logger) { // No init service to notify } diff --git a/config/config.go b/config/config.go index 0d1be128..5bf994a1 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log/slog" "math" "os" "os/signal" @@ -16,7 +17,6 @@ import ( "time" "dario.cat/mergo" - "github.com/sirupsen/logrus" "go.yaml.in/yaml/v3" ) @@ -26,11 +26,11 @@ type C struct { Settings map[string]any oldSettings map[string]any callbacks []func(*C) - l *logrus.Logger + l *slog.Logger reloadLock sync.Mutex } -func NewC(l *logrus.Logger) *C { +func NewC(l *slog.Logger) *C { return &C{ Settings: make(map[string]any), l: l, @@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool { newVals, err := yaml.Marshal(nv) if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error 
while marshaling new config") + c.l.Error("Error while marshaling new config", + "config_path", k, + "error", err, + ) } oldVals, err := yaml.Marshal(ov) if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") + c.l.Error("Error while marshaling old config", + "config_path", k, + "error", err, + ) } return string(newVals) != string(oldVals) @@ -154,7 +160,10 @@ func (c *C) ReloadConfig() { err := c.Load(c.path) if err != nil { - c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") + c.l.Error("Error occurred while reloading config", + "config_path", c.path, + "error", err, + ) return } diff --git a/connection_manager.go b/connection_manager.go index 4c2f26ef..e7fc04cd 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -5,13 +5,13 @@ import ( "context" "encoding/binary" "fmt" + "log/slog" "net/netip" "sync" "sync/atomic" "time" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -47,10 +47,10 @@ type connectionManager struct { metricsTxPunchy metrics.Counter - l *logrus.Logger + l *slog.Logger } -func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { +func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { cm := &connectionManager{ hostMap: hm, l: l, @@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) { old := cm.getInactivityTimeout() cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))) if !initial { - cm.l.WithField("oldDuration", old). - WithField("newDuration", cm.getInactivityTimeout()). 
- Info("Inactivity timeout has changed") + cm.l.Info("Inactivity timeout has changed", + "oldDuration", old, + "newDuration", cm.getInactivityTimeout(), + ) } } @@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) { old := cm.dropInactive.Load() cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false)) if !initial { - cm.l.WithField("oldBool", old). - WithField("newBool", cm.dropInactive.Load()). - Info("Drop inactive setting has changed") + cm.l.Info("Drop inactive setting has changed", + "oldBool", old, + "newBool", cm.dropInactive.Load(), + ) } } } @@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo var err error index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { - cm.l.WithError(err).Error("failed to migrate relay to new hostinfo") + cm.l.Error("failed to migrate relay to new hostinfo", "error", err) continue } switch r.Type { @@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo msg, err := req.Marshal() if err != nil { - cm.l.WithError(err).Error("failed to marshal Control message to migrate relay") + cm.l.Error("failed to marshal Control message to migrate relay", "error", err) } else { cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) - cm.l.WithFields(logrus.Fields{ - "relayFrom": req.RelayFromAddr, - "relayTo": req.RelayToAddr, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnAddrs": newhostinfo.vpnAddrs}). 
- Info("send CreateRelayRequest") + cm.l.Info("send CreateRelayRequest", + "relayFrom", req.RelayFromAddr, + "relayTo", req.RelayToAddr, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddrs", newhostinfo.vpnAddrs, + ) } } } @@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim hostinfo := cm.hostMap.Indexes[localIndex] if hostinfo == nil { - cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap") + cm.l.Debug("Not found in hostmap", "localIndex", localIndex) return doNothing, nil, nil } @@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim // A hostinfo is determined alive if there is incoming traffic if inTraffic { decision := doNothing - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). - Debug("Tunnel status") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Tunnel status", + "tunnelCheck", m{"state": "alive", "method": "passive"}, + ) } hostinfo.pendingDeletion.Store(false) @@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if hostinfo.pendingDeletion.Load() { // We have already sent a test packet and nothing was returned, this hostinfo is dead - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "dead", "method": "active"}). - Info("Tunnel status") + hostinfo.logger(cm.l).Info("Tunnel status", + "tunnelCheck", m{"state": "dead", "method": "active"}, + ) return deleteTunnel, hostinfo, nil } @@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim inactiveFor, isInactive := cm.isInactive(hostinfo, now) if isInactive { // Tunnel is inactive, tear it down - hostinfo.logger(cm.l). - WithField("inactiveDuration", inactiveFor). - WithField("primary", mainHostInfo). 
- Info("Dropping tunnel due to inactivity") + hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity", + "inactiveDuration", inactiveFor, + "primary", mainHostInfo, + ) return closeTunnel, hostinfo, primary } @@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim cm.sendPunch(hostinfo) } - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "testing", "method": "active"}). - Debug("Tunnel status") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Tunnel status", + "tunnelCheck", m{"state": "testing", "method": "active"}, + ) } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues decision = sendTestPacket } else { - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l).Debugf("Hostinfo sadness") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Hostinfo sadness") } } @@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI return false //cert is still valid! yay! } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed // Block listed certificates should always be disconnected - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). - Info("Remote certificate is blocked, tearing down the tunnel") + hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel", + "error", err, + "fingerprint", remoteCert.Fingerprint, + ) return true } else if cm.intf.disconnectInvalid.Load() { - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). 
- Info("Remote certificate is no longer valid, tearing down the tunnel") + hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel", + "error", err, + "fingerprint", remoteCert.Fingerprint, + ) return true } else { //if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open @@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { curCrtVersion := curCrt.Version() myCrt := cs.getCertificate(curCrtVersion) if myCrt == nil { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("version", curCrtVersion). - WithField("reason", "local certificate removed"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "version", curCrtVersion, + "reason", "local certificate removed", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } @@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { // if our certificate version is less than theirs, and we have a matching version available, rehandshake? if cs.getCertificate(peerCrt.Certificate.Version()) != nil { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("version", curCrtVersion). - WithField("peerVersion", peerCrt.Certificate.Version()). - WithField("reason", "local certificate version lower than peer, attempting to correct"). 
- Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "version", curCrtVersion, + "peerVersion", peerCrt.Certificate.Version(), + "reason", "local certificate version lower than peer, attempting to correct", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { hh.initiatingVersionOverride = peerCrt.Certificate.Version() }) @@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { } } if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "local certificate is not current"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "reason", "local certificate is not current", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } if curCrtVersion < cs.initiatingVersion { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "current cert version < pki.initiatingVersion"). 
- Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "reason", "current cert version < pki.initiatingVersion", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return diff --git a/connection_manager_test.go b/connection_manager_test.go index 647dd72b..a015fba9 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/overlaytest" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) @@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := 
NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) @@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) conf.Settings["tunnels"] = map[string]any{ "drop_inactive": true, } - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) assert.True(t, nc.dropInactive.Load()) nc.intf = ifce @@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.disconnectInvalid.Store(true) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce ifce.connectionManager = nc diff --git a/connection_state.go b/connection_state.go index db885d42..b85aebd4 100644 --- a/connection_state.go +++ b/connection_state.go @@ -8,7 +8,6 @@ import ( "sync/atomic" "github.com/flynn/noise" - 
"github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/noiseutil" ) @@ -27,7 +26,7 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { +func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc switch crt.Curve() { case cert.Curve_CURVE25519: diff --git a/control.go b/control.go index 75eccef1..ef58988b 100644 --- a/control.go +++ b/control.go @@ -3,13 +3,13 @@ package nebula import ( "context" "errors" + "log/slog" "net/netip" "os" "os/signal" "sync" "syscall" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" @@ -46,7 +46,7 @@ type Control struct { state RunState f *Interface - l *logrus.Logger + l *slog.Logger ctx context.Context cancel context.CancelFunc sshStart func() @@ -151,7 +151,7 @@ func (c *Control) Stop() { c.CloseAllTunnels(false) if err := c.f.Close(); err != nil { - c.l.WithError(err).Error("Close interface failed") + c.l.Error("Close interface failed", "error", err) } c.stateLock.Lock() c.state = StateStopped @@ -166,7 +166,7 @@ func (c *Control) ShutdownBlock() { rawSig := <-sigChan sig := rawSig.String() - c.l.WithField("signal", sig).Info("Caught signal, shutting down") + c.l.Info("Caught signal, shutting down", "signal", sig) c.Stop() } @@ -303,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). 
- Debug("Sending close tunnel message") + c.l.Debug("Sending close tunnel message", + "vpnAddrs", h.vpnAddrs, + "udpAddr", h.remote, + ) closed++ } diff --git a/control_test.go b/control_test.go index 558d8669..5e381c46 100644 --- a/control_test.go +++ b/control_test.go @@ -6,7 +6,6 @@ import ( "reflect" "testing" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" @@ -83,7 +82,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { f: &Interface{ hostMap: hm, }, - l: logrus.New(), + l: test.NewLogger(), } thi := c.GetHostInfoByVpnAddr(vpnIp, false) diff --git a/dns_server.go b/dns_server.go index 5b12b922..ff1369ab 100644 --- a/dns_server.go +++ b/dns_server.go @@ -3,6 +3,7 @@ package nebula import ( "context" "fmt" + "log/slog" "net" "net/netip" "strconv" @@ -12,13 +13,12 @@ import ( "github.com/gaissmai/bart" "github.com/miekg/dns" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type dnsServer struct { sync.RWMutex - l *logrus.Logger + l *slog.Logger ctx context.Context dnsMap4 map[string]netip.Addr dnsMap6 map[string]netip.Addr @@ -55,7 +55,7 @@ type dnsServer struct { // they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel // watcher that tears the listener down on nebula shutdown. The returned // pointer is always non-nil, even on error. 
-func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { +func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { ds := &dnsServer{ l: l, ctx: ctx, @@ -69,7 +69,7 @@ func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState c.RegisterReloadCallback(func(c *config.C) { if err := ds.reload(c, false); err != nil { - l.WithError(err).Error("Failed to reload DNS responder from config") + ds.l.Error("Failed to reload DNS responder from config", "error", err) } }) @@ -145,7 +145,7 @@ func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reaso <-started } if err := srv.Shutdown(); err != nil { - d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder") + d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err) } } @@ -188,7 +188,7 @@ func (d *dnsServer) Start() { } }() - d.l.WithField("dnsListener", addr).Info("Starting DNS responder") + d.l.Info("Starting DNS responder", "dnsListener", addr) err := server.ListenAndServe() close(done) @@ -201,7 +201,7 @@ func (d *dnsServer) Start() { } if err != nil { - d.l.WithError(err).Warn("Failed to run the DNS responder") + d.l.Warn("Failed to run the DNS responder", "error", err) } } @@ -314,6 +314,7 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { } func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { + debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug) // Per RFC 2308 §2.2, a name that exists but has no record of the requested // type must be answered with NOERROR and an empty answer section (NODATA), // not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not @@ -323,7 +324,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { switch q.Qtype { case dns.TypeA, dns.TypeAAAA: qType := 
dns.TypeToString[q.Qtype] - d.l.Debugf("Query for %s %s", qType, q.Name) + if debugEnabled { + d.l.Debug("DNS query", "type", qType, "name", q.Name) + } ip, nameExists := d.Query(q.Qtype, q.Name) if nameExists { anyNameExists = true @@ -339,7 +342,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - d.l.Debugf("Query for TXT %s", q.Name) + if debugEnabled { + d.l.Debug("DNS query", "type", "TXT", "name", q.Name) + } ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) diff --git a/dns_server_test.go b/dns_server_test.go index e09d3fa9..dcea046c 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -2,7 +2,7 @@ package nebula import ( "context" - "io" + "log/slog" "net" "net/netip" "strconv" @@ -10,7 +10,6 @@ import ( "time" "github.com/miekg/dns" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -30,7 +29,7 @@ func (stubDNSWriter) TsigTimersOnly(bool) {} func (stubDNSWriter) Hijack() {} func TestParsequery(t *testing.T) { - l := logrus.New() + l := slog.New(slog.DiscardHandler) hostMap := &HostMap{} ds := &dnsServer{ l: l, @@ -137,10 +136,9 @@ func Test_getDnsServerAddr(t *testing.T) { func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { t.Helper() - l := logrus.New() - l.Out = io.Discard + sl := slog.New(slog.DiscardHandler) ds := &dnsServer{ - l: l, + l: sl, ctx: context.Background(), dnsMap4: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr), @@ -148,7 +146,7 @@ func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { } ds.mux = dns.NewServeMux() ds.mux.HandleFunc(".", ds.handleDnsRequest) - return ds, config.NewC(l) + return ds, config.NewC(nil) } func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 
7729465b..93f200ac 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -11,7 +11,6 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" @@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) { myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) - l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) @@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - l.Info("Get a tunnel between me and relay") assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - l.Info("Get a tunnel between them and relay") assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - l.Info("Trigger a handshake from both them and me via relay to them and me") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) - r.Log("Wait for a packet from 
them to me") - l.Info("Wait for a packet from them to me; myControl") + r.Log("Wait for a packet from them to me; myControl") r.RouteForAllUntilTxTun(myControl) - l.Info("Wait for a packet from them to me; theirControl") + r.Log("Wait for a packet from them to me; theirControl") r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - l.Info("Assert the tunnel works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") - l.Info("Wait until we remove extra tunnels") - l.WithFields( - logrus.Fields{ - "myControl": len(myControl.GetHostmap().Indexes), - "theirControl": len(theirControl.GetHostmap().Indexes), - "relayControl": len(relayControl.GetHostmap().Indexes), - }).Info("Waiting for hostinfos to be removed...") + t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d", + len(myControl.GetHostmap().Indexes), + len(theirControl.GetHostmap().Indexes), + len(relayControl.GetHostmap().Indexes), + ) hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) retries := 60 for hostInfos > 6 && retries > 0 { hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) - l.WithFields( - logrus.Fields{ - "myControl": len(myControl.GetHostmap().Indexes), - "theirControl": len(theirControl.GetHostmap().Indexes), - "relayControl": len(relayControl.GetHostmap().Indexes), - }).Info("Waiting for hostinfos to be removed...") + t.Logf("Waiting for hostinfos to be removed... 
myControl=%d theirControl=%d relayControl=%d", + len(myControl.GetHostmap().Indexes), + len(theirControl.GetHostmap().Indexes), + len(relayControl.GetHostmap().Indexes), + ) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) @@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) { } r.Log("Assert the tunnel works") - l.Info("Assert the tunnel works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 39843efe..381ae897 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -4,7 +4,6 @@ package e2e import ( - "fmt" "io" "net/netip" "os" @@ -12,15 +11,18 @@ import ( "testing" "time" + "log/slog" + "dario.cat/mergo" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.yaml.in/yaml/v3" @@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific "port": udpAddr.Port(), }, "logging": m{ - "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), - "level": l.Level.String(), + "level": testLogLevelName(), }, "timers": m{ "pending_deletion_interval": 2, @@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o "port": udpAddr.Port(), }, "logging": m{ - "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()), - "level": l.Level.String(), + "level": testLogLevelName(), }, "timers": m{ "pending_deletion_interval": 2, @@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr { return a } -func 
NewTestLogger() *logrus.Logger { - l := logrus.New() - +func NewTestLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(io.Discard) - l.SetLevel(logrus.PanicLevel) - return l + return slog.New(slog.NewTextHandler(io.Discard, nil)) } + level := slog.LevelInfo switch v { case "2": - l.SetLevel(logrus.DebugLevel) + level = slog.LevelDebug case "3": - l.SetLevel(logrus.TraceLevel) - default: - l.SetLevel(logrus.InfoLevel) + level = logging.LevelTrace } - - return l + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} + +// testLogLevelName returns the level name string accepted by logging.ApplyConfig +// for the current TEST_LOGS setting. Kept in sync with NewTestLogger. +func testLogLevelName() string { + switch os.Getenv("TEST_LOGS") { + case "2": + return "debug" + case "3": + return "trace" + case "": + return "info" + } + return "info" } diff --git a/examples/config.yml b/examples/config.yml index 5bb87d8e..b02b3d58 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -292,23 +292,17 @@ tun: # Configure logging level logging: - # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. - #NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some - # scenarios. Debug logging is also CPU intensive and will decrease performance overall. - # Only enable debug logging while actively investigating an issue. + # trace, debug, info, warn, or error. Default is info and is reloadable. + # fatal and panic are accepted for backwards compatibility and map to error. + #NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some + # scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall. + # Only enable debug or trace logging while actively investigating an issue. level: info - # json or text formats currently available. 
Default is text + # json or text formats currently available. Default is text. format: text - # Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false + # Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false. #disable_timestamp: true - # timestamp format is specified in Go time format, see: - # https://golang.org/pkg/time/#pkg-constants - # default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339) - # default when `format: text`: - # when TTY attached: seconds since beginning of execution - # otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339) - # As an example, to log as RFC3339 with millisecond precision, set to: - #timestamp_format: "2006-01-02T15:04:05.000Z07:00" + # Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable. #stats: #type: graphite diff --git a/examples/go_service/main.go b/examples/go_service/main.go index 2f8efbfb..3f98fe3d 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -7,9 +7,9 @@ import ( "net" "os" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/service" ) @@ -64,8 +64,7 @@ pki: return err } - logger := logrus.New() - logger.Out = os.Stdout + logger := logging.NewLogger(os.Stdout) ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { diff --git a/firewall.go b/firewall.go index 93b16891..adecbe81 100644 --- a/firewall.go +++ b/firewall.go @@ -1,11 +1,13 @@ package nebula import ( + "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "hash/fnv" + "log/slog" "net/netip" "reflect" "slices" @@ -16,7 +18,6 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" 
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -67,7 +68,7 @@ type Firewall struct { incomingMetrics firewallMetrics outgoingMetrics firewallMetrics - l *logrus.Logger + l *slog.Logger } type firewallMetrics struct { @@ -131,7 +132,7 @@ type firewallLocalCIDR struct { // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { +func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration @@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) @@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.InSendReject = false default: - l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`") + l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction) fw.InSendReject = false } @@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.OutSendReject = false default: - l.WithField("action", outboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") + l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction) fw.OutSendReject = false } @@ -268,7 +269,7 @@ func (f *Firewall) 
AddRule(incoming bool, proto uint8, startPort int32, endPort case firewall.ProtoICMP, firewall.ProtoICMPv6: //ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided if startPort != firewall.PortAny { - f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule") + f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort) } startPort = firewall.PortAny endPort = firewall.PortAny @@ -290,8 +291,9 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort if !incoming { direction = "outgoing" } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). - Info("Firewall rule added") + f.l.Info("Firewall rule added", + "firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}, + ) return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) } @@ -314,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } -func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { +func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { table = "firewall.inbound" @@ -372,7 +374,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw startPort = firewall.PortAny endPort = firewall.PortAny if sPort != "" { - l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule") + l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort) } default: 
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) @@ -396,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw } if warning := r.sanity(); warning != nil { - l.Warnf("%s rule #%v; %s", table, i, warning) + l.Warn("firewall rule sanity check", + "table", table, + "rule", i, + "warning", warning, + ) } err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) @@ -528,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, // We now know which firewall table to check against if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { - if f.l.Level >= logrus.DebugLevel { - h.logger(f.l). - WithField("fwPacket", fp). - WithField("incoming", c.incoming). - WithField("rulesVersion", f.rulesVersion). - WithField("oldRulesVersion", c.rulesVersion). - Debugln("dropping old conntrack entry, does not match new ruleset") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset", + "fwPacket", fp, + "incoming", c.incoming, + "rulesVersion", f.rulesVersion, + "oldRulesVersion", c.rulesVersion, + ) } delete(conntrack.Conns, fp) conntrack.Unlock() return false } - if f.l.Level >= logrus.DebugLevel { - h.logger(f.l). - WithField("fwPacket", fp). - WithField("incoming", c.incoming). - WithField("rulesVersion", f.rulesVersion). - WithField("oldRulesVersion", c.rulesVersion). 
- Debugln("keeping old conntrack entry, does match new ruleset") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset", + "fwPacket", fp, + "incoming", c.incoming, + "rulesVersion", f.rulesVersion, + "oldRulesVersion", c.rulesVersion, + ) } c.rulesVersion = f.rulesVersion @@ -935,7 +941,7 @@ type rule struct { CASha string } -func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { +func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) { r := rule{} m, ok := p.(map[string]any) @@ -966,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { return r, errors.New("group should contain a single value, an array with more than one entry was provided") } - l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) + l.Warn("group was an array with a single value, converting to simple value", + "table", table, + "rule", i, + ) m["group"] = v[0] } diff --git a/firewall/cache.go b/firewall/cache.go index a4ffc100..ba4b9732 100644 --- a/firewall/cache.go +++ b/firewall/cache.go @@ -2,10 +2,9 @@ package firewall import ( "context" + "log/slog" "sync/atomic" "time" - - "github.com/sirupsen/logrus" ) // ConntrackCache is used as a local routine cache to know if a given flow @@ -16,15 +15,17 @@ type ConntrackCacheTicker struct { cacheV uint64 cacheTick atomic.Uint64 + l *slog.Logger cache ConntrackCache } -func NewConntrackCacheTicker(ctx context.Context, d time.Duration) *ConntrackCacheTicker { +func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker { if d == 0 { return nil } c := &ConntrackCacheTicker{ + l: l, cache: ConntrackCache{}, } @@ -48,15 +49,15 @@ func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) { // Get checks if the cache ticker has moved to the next version before returning // the map. 
If it has moved, we reset the map. -func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { +func (c *ConntrackCacheTicker) Get() ConntrackCache { if c == nil { return nil } if tick := c.cacheTick.Load(); tick != c.cacheV { c.cacheV = tick if ll := len(c.cache); ll > 0 { - if l.Level == logrus.DebugLevel { - l.WithField("len", ll).Debug("resetting conntrack cache") + if c.l.Enabled(context.Background(), slog.LevelDebug) { + c.l.Debug("resetting conntrack cache", "len", ll) } c.cache = make(ConntrackCache, ll) } diff --git a/firewall/cache_test.go b/firewall/cache_test.go new file mode 100644 index 00000000..ab807984 --- /dev/null +++ b/firewall/cache_test.go @@ -0,0 +1,69 @@ +package firewall + +import ( + "bytes" + "log/slog" + "strings" + "testing" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" +) + +// The tests below pin the log format produced by ConntrackCacheTicker.Get +// so changes cannot silently break what operators are grepping for. The +// ticker's internal state (cache + cacheTick) is poked directly to avoid +// racing a goroutine-driven tick in tests. 
+ +func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker { + t.Helper() + c := &ConntrackCacheTicker{ + l: l, + cache: make(ConntrackCache, cacheLen), + } + for i := 0; i < cacheLen; i++ { + c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{} + } + c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path + return c +} + +func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 3) + c.Get() + + assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String()) +} + +func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 2) + c.Get() + + assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String())) +} + +func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo) + + c := newFixedTicker(t, l, 5) + c.Get() + + assert.Empty(t, buf.String()) +} + +func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 0) + c.Get() + + assert.Empty(t, buf.String()) +} diff --git a/firewall_test.go b/firewall_test.go index a2133760..cbf090fd 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -3,13 +3,13 @@ package nebula import ( "bytes" "errors" + "log/slog" "math" "net/netip" "testing" "time" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) { } func TestFirewall_AddRule(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - 
l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) @@ -177,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) { } func TestFirewall_Drop(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ @@ -254,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) { } func TestFirewall_DropV6(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) @@ -485,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) { } func TestFirewall_Drop2(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -544,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) { } func TestFirewall_Drop3(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -633,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) { } func TestFirewall_Drop3V6(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) @@ -671,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) { } func TestFirewall_DropConntrackReload(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -736,9 
+729,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { } func TestFirewall_ICMPPortBehavior(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -880,9 +872,8 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { } func TestFirewall_DropIPSpoofing(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) @@ -1045,25 +1036,25 @@ func TestNewFirewallFromConfig(t *testing.T) { cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) require.NoError(t, err) - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", 
"proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") @@ -1073,25 +1064,25 @@ func TestNewFirewallFromConfig(t *testing.T) { require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") @@ -1100,35 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) { l := test.NewLogger() // Test adding tcp rule - conf := config.NewC(l) + conf := 
config.NewC(test.NewLogger()) mf := &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding udp rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule no port - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding any rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": 
"1", "proto": "any", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) @@ -1136,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding rule with cidr cidr := netip.MustParsePrefix("10.0.0.0/8") - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) // Test adding rule with local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) @@ -1151,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding rule with cidr ipv6 cidr6 := netip.MustParsePrefix("fd00::/8") - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) // Test adding rule with any cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, 
groups: nil, ip: "any", localIp: ""}, mf.lastCall) // Test adding rule with junk cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with local_cidr ipv6 - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) // Test adding rule with any local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) // Test adding rule with junk local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with ca_sha - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} 
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) // Test single group - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test single groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test multiple AND groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} require.NoError(t, 
AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) // Test Add error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} @@ -1234,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { } func TestFirewall_convertRule(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) // Ensure group array of 1 is converted and a warning is printed c := map[string]any{ @@ -1244,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) { } r, err := convertRule(l, c, "test", 1) - assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") + assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value") + assert.Contains(t, ob.String(), "table=test") + assert.Contains(t, ob.String(), "rule=1") require.NoError(t, err) assert.Equal(t, []string{"group1"}, r.Groups) @@ -1270,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) { } func TestFirewall_convertRuleSanity(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) noWarningPlease := []map[string]any{ {"group": "group1"}, @@ -1386,7 +1377,7 @@ type testsetup struct { fw *Firewall } -func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup { +func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup { c := dummyCert{ name: "me", networks: myPrefixes, @@ -1397,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse return newSetupFromCert(t, l, c) } -func newSetupFromCert(t 
*testing.T, l *logrus.Logger, c dummyCert) testsetup { +func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup { myVpnNetworksTable := new(bart.Lite) for _, prefix := range c.Networks() { myVpnNetworksTable.Insert(prefix) @@ -1414,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { t.Parallel() - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myPrefix := netip.MustParsePrefix("1.1.1.1/8") // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out diff --git a/go.mod b/go.mod index 169cf1ca..0de2df7d 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.23.2 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 - github.com/sirupsen/logrus v1.9.4 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index d56177b7..aad164c7 100644 --- a/go.sum +++ b/go.sum @@ -133,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= -github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod 
h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= diff --git a/handshake_ix.go b/handshake_ix.go index f081eb8c..a086960e 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,11 +2,12 @@ package nebula import ( "bytes" + "context" + "log/slog" "net/netip" "time" "github.com/flynn/noise" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" ) @@ -18,8 +19,11 @@ import ( func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") + f.l.Error("Failed to generate index", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { crt := cs.getCertificate(v) if crt == nil { - f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Unable to handshake with host because no certificate is available") + f.l.Error("Unable to handshake with host because no certificate is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } crtHs := cs.getHandshakeBytes(v) if crtHs == nil { - f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). 
- Error("Unable to handshake with host because no certificate handshake bytes is available") + f.l.Error("Unable to handshake with host because no certificate handshake bytes is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } - ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Failed to create connection state") + f.l.Error("Failed to create connection state", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } hh.hostinfo.ConnectionState = ci @@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("certVersion", v). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + f.l.Error("Failed to marshal handshake message", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "certVersion", v, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). 
- WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + f.l.Error("Failed to call noise.WriteMessage", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) cs := f.pki.getCertState() crt := cs.GetDefaultCertificate() if crt == nil { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", cs.initiatingVersion). - Error("Unable to handshake with host because no certificate is available") + f.l.Error("Unable to handshake with host because no certificate is available", + "from", via, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", cs.initiatingVersion, + ) return } - ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) + ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to create connection state") + f.l.Error("Failed to create connection state", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to call noise.ReadMessage") + f.l.Error("Failed to call noise.ReadMessage", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Error("Failed unmarshal handshake message") + f.l.Error("Failed unmarshal handshake message", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake did not contain a certificate") + f.l.Info("Handshake did not contain a certificate", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) fp = "" } - e := f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("certVpnNetworks", rc.Networks()). - WithField("certFingerprint", fp) - - if f.l.Level >= logrus.DebugLevel { - e = e.WithField("cert", rc) + attrs := []slog.Attr{ + slog.Any("error", err), + slog.Any("from", via), + slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}), + slog.Any("certVpnNetworks", rc.Networks()), + slog.String("certFingerprint", fp), + } + if f.l.Enabled(context.Background(), slog.LevelDebug) { + attrs = append(attrs, slog.Any("cert", rc)) } - e.Info("Invalid certificate from host") + // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that + // callers grow conditionally, which has no pair-form equivalent. + //nolint:sloglint + f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) return } if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") + f.l.Info("public key mismatch between certificate and handshake", + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "cert", remoteCert, + ) return } @@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) // We started off using the wrong certificate version, lets see if we can match the version that was sent to us myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) if myCertOtherVersion == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithError(err).WithFields(m{ - "from": via, - "handshake": m{"stage": 1, "style": "ix_psk0"}, - "cert": remoteCert, - }).Debug("Might be unable to handshake with host due to missing certificate version") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Might be unable to handshake with host due to missing certificate version", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "cert", remoteCert, + ) } } else { // Record the certificate we are actually using @@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("from", via). - WithField("cert", remoteCert). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("No networks in certificate") + f.l.Info("No networks in certificate", + "error", err, + "from", via, + "cert", remoteCert, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) vpnAddrs := make([]netip.Addr, len(vpnNetworks)) for i, network := range vpnNetworks { if f.myVpnAddrsTable.Contains(network.Addr()) { - f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). - WithField("certName", certName). 
- WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } vpnAddrs[i] = network.Addr() @@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) if !via.IsRelayed { // We only want to apply the remote allow list for direct tunnels here if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - Debug("lighthouse.remote_allow_list denied incoming handshake") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", vpnAddrs, + "from", via, + ) + } return } } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). 
- WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") + f.l.Error("Failed to generate index", + "error", err, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) }, } - msgRxL := f.l.WithFields(m{ - "vpnAddrs": vpnAddrs, - "from": via, - "certName": certName, - "certVersion": certVersion, - "fingerprint": fingerprint, - "issuer": issuer, - "initiatorIndex": hs.Details.InitiatorIndex, - "responderIndex": hs.Details.ResponderIndex, - "remoteIndex": h.RemoteIndex, - "handshake": m{"stage": 1, "style": "ix_psk0"}, - }) + msgRxL := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) if anyVpnAddrsInCommon { msgRxL.Info("Handshake message received") @@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) hs.Details.ResponderIndex = myIndex hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) if hs.Details.Cert == nil { - msgRxL.WithField("myCertVersion", ci.myCert.Version()). - Error("Unable to handshake with host because no certificate handshake bytes is available") + msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available", + "myCertVersion", ci.myCert.Version(), + ) return } @@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). 
- WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + f.l.Error("Failed to marshal handshake message", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + f.l.Error("Failed to call noise.WriteMessage", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). 
- WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") + f.l.Error("Noise did not arrive at a key", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) if !via.IsRelayed { err := f.outside.WriteTo(msg, via.UdpAddr) if err != nil { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - WithError(err).Error("Failed to send handshake message") + f.l.Error("Failed to send handshake message", + "vpnAddrs", existing.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + "error", err, + ) } else { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", existing.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + ) } return } else { @@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) } hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). 
- Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", existing.vpnAddrs, + "relay", via.relayHI.vpnAddrs[0], + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + ) return } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("oldHandshakeTime", existing.lastHandshakeTime). - WithField("newHandshakeTime", hostinfo.lastHandshakeTime). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake too old") + f.l.Info("Handshake too old", + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "oldHandshakeTime", existing.lastHandshakeTime, + "newHandshakeTime", hostinfo.lastHandshakeTime, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). 
- WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). - Error("Failed to add HostInfo due to localIndex collision") + f.l.Error("Failed to add HostInfo due to localIndex collision", + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "localIndex", hostinfo.localIndexId, + "collision", existing.vpnAddrs, + ) return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Error("Failed to add HostInfo to HostMap") + f.l.Error("Failed to add HostInfo to HostMap", + "error", err, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } } @@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) if !via.IsRelayed { err = f.outside.WriteTo(msg, via.UdpAddr) - log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + log := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) if err != nil { - log.WithError(err).Error("Failed to send handshake") + log.Error("Failed to send handshake", "error", err) } else { log.Info("Handshake message sent") } @@ -448,14 +533,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) // it's correctly marked as working. 
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", vpnAddrs, + "relay", via.relayHI.vpnAddrs[0], + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) } f.connectionManager.AddTrafficWatch(hostinfo) @@ -483,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe if !via.IsRelayed { // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. 
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + ) + } return false } } @@ -491,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Error("Failed to call noise.ReadMessage") + f.l.Error("Failed to call noise.ReadMessage", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "header", h, + ) // We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Error("Noise did not arrive at a key") + f.l.Error("Noise did not arrive at a key", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // This should be impossible in IX but just in case, if we get here then there is no chance to recover // the handshake state machine. 
Tear it down @@ -512,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + f.l.Error("Failed unmarshal handshake message", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true @@ -521,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake did not contain a certificate") + f.l.Info("Handshake did not contain a certificate", + "error", err, + "from", via, + "vpnAddrs", hostinfo.vpnAddrs, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) return true } @@ -535,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe fp = "" } - e := f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("certFingerprint", fp). 
- WithField("certVpnNetworks", rc.Networks()) - - if f.l.Level >= logrus.DebugLevel { - e = e.WithField("cert", rc) + attrs := []slog.Attr{ + slog.Any("error", err), + slog.Any("from", via), + slog.Any("vpnAddrs", hostinfo.vpnAddrs), + slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}), + slog.String("certFingerprint", fp), + slog.Any("certVpnNetworks", rc.Networks()), + } + if f.l.Enabled(context.Background(), slog.LevelDebug) { + attrs = append(attrs, slog.Any("cert", rc)) } - e.Info("Invalid certificate from host") + // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that + // callers grow conditionally, which has no pair-form equivalent. + //nolint:sloglint + f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) return true } if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") + f.l.Info("public key mismatch between certificate and handshake", + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cert", remoteCert, + ) return true } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("cert", remoteCert). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("No networks in certificate") + f.l.Info("No networks in certificate", + "error", err, + "from", via, + "vpnAddrs", hostinfo.vpnAddrs, + "cert", remoteCert, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) return true } @@ -601,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe // Ensure the right host responded if !correctHostResponded { - f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). - WithField("from", via). 
- WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Incorrect host responded to handshake") + f.l.Info("Incorrect host responded to handshake", + "intendedVpnAddrs", hostinfo.vpnAddrs, + "haveVpnNetworks", vpnNetworks, + "from", via, + "certName", certName, + "certVersion", certVersion, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // Release our old handshake from pending, it should not continue f.handshakeManager.DeleteHostInfo(hostinfo) @@ -618,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes.BlockRemote(via) - f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). - WithField("vpnNetworks", vpnNetworks). - WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). - Info("Blocked addresses for handshakes") + f.l.Info("Blocked addresses for handshakes", + "blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(), + "vpnNetworks", vpnNetworks, + "remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()), + ) // Swap the packet store to benefit the original intended recipient newHH.packetStore = hh.packetStore @@ -639,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("durationNs", duration). 
- WithField("sentCachedPackets", len(hh.packetStore)) + msgRxL := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "durationNs", duration, + "sentCachedPackets", len(hh.packetStore), + ) if anyVpnAddrsInCommon { msgRxL.Info("Handshake message received") } else { @@ -663,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Sending stored packets", + "count", len(hh.packetStore), + ) } if len(hh.packetStore) > 0 { diff --git a/handshake_manager.go b/handshake_manager.go index 25a59b6d..8040ec2e 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -6,13 +6,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "log/slog" "net/netip" "slices" "sync" "time" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -59,7 +59,7 @@ type HandshakeManager struct { metricInitiated metrics.Counter metricTimedOut metrics.Counter f *Interface - l *logrus.Logger + l *slog.Logger // can be used to trigger outbound handshake for the given vpnIp trigger chan netip.Addr @@ -78,32 +78,32 @@ type HandshakeHostInfo struct { hostinfo *HostInfo } -func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { +func (hh *HandshakeHostInfo) cachePacket(l 
*slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { if len(hh.packetStore) < 100 { tempPacket := make([]byte, len(packet)) copy(tempPacket, packet) hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) - if l.Level >= logrus.DebugLevel { - hh.hostinfo.logger(l). - WithField("length", len(hh.packetStore)). - WithField("stored", true). - Debugf("Packet store") + if l.Enabled(context.Background(), slog.LevelDebug) { + hh.hostinfo.logger(l).Debug("Packet store", + "length", len(hh.packetStore), + "stored", true, + ) } } else { m.dropped.Inc(1) - if l.Level >= logrus.DebugLevel { - hh.hostinfo.logger(l). - WithField("length", len(hh.packetStore)). - WithField("stored", false). - Debugf("Packet store") + if l.Enabled(context.Background(), slog.LevelDebug) { + hh.hostinfo.logger(l).Debug("Packet store", + "length", len(hh.packetStore), + "stored", false, + ) } } } -func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { +func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, @@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head // First remote allow list check before we know the vpnIp if !via.IsRelayed { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { - hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") + hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via) return } } @@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo := hh.hostinfo // If we are out of time, clean up if 
hh.counter >= hm.config.retries { - hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). - WithField("initiatorIndex", hh.hostinfo.localIndexId). - WithField("remoteIndex", hh.hostinfo.remoteIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). - Info("Handshake timed out") + hh.hostinfo.logger(hm.l).Info("Handshake timed out", + "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), + "initiatorIndex", hh.hostinfo.localIndexId, + "remoteIndex", hh.hostinfo.remoteIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "durationNs", time.Since(hh.startTime).Nanoseconds(), + ) hm.metricTimedOut.Inc(1) hm.DeleteHostInfo(hostinfo) return @@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { - hostinfo.logger(hm.l).WithField("udpAddr", addr). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake message") + hostinfo.logger(hm.l).Error("Failed to send handshake message", + "udpAddr", addr, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "error", err, + ) } else { sentTo = append(sentTo, addr) @@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, // so only log when the list of remotes has changed if remotesHaveChanged { - hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). - WithField("initiatorIndex", hostinfo.localIndexId). 
- WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake message sent") - } else if hm.l.Level >= logrus.DebugLevel { - hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Debug("Handshake message sent") + hostinfo.logger(hm.l).Info("Handshake message sent", + "udpAddrs", sentTo, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) + } else if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(hm.l).Debug("Handshake message sent", + "udpAddrs", sentTo, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) } if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") + hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay through the host I'm trying to connect to @@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") + hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String()) hm.f.Handshake(relay) continue } @@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { - hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") + 
hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) } m := NebulaControl{ @@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered msg, err := m.Marshal() if err != nil { - hostinfo.logger(hm.l). - WithError(err). - Error("Failed to marshal Control message to create relay") + hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], - "relayTo": vpnIp, - "initiatorRelayIndex": idx, - "relay": relay}). - Info("send CreateRelayRequest") + hm.l.Info("send CreateRelayRequest", + "relayFrom", hm.f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", idx, + "relay", relay, + ) } } continue @@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered switch existingRelay.State { case Established: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String()) hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) fallthrough case Requested: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String()) // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, @@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } msg, err := m.Marshal() if err != nil { - hostinfo.logger(hm.l). - WithError(err). 
- Error("Failed to marshal Control message to create relay") + hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) } else { // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], - "relayTo": vpnIp, - "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": relay}). - Info("send CreateRelayRequest") + hm.l.Info("send CreateRelayRequest", + "relayFrom", hm.f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", existingRelay.LocalIndex, + "relay", relay, + ) } case PeerRequested: // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. fallthrough default: - hostinfo.logger(hm.l). - WithField("vpnIp", vpnIp). - WithField("state", existingRelay.State). - WithField("relay", relay). - Errorf("Relay unexpected state") + hostinfo.logger(hm.l).Error("Relay unexpected state", + "vpnIp", vpnIp, + "state", existingRelay.State, + "relay", relay, + ) } } @@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). 
- Info("New host shadows existing host remoteIndex") + hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", + "remoteIndex", hostinfo.remoteIndexId, + "collision", existingRemoteIndex.vpnAddrs, + ) } hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) @@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). - Info("New host shadows existing host remoteIndex") + hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", + "remoteIndex", hostinfo.remoteIndexId, + "collision", existingRemoteIndex.vpnAddrs, + ) } // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. @@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { hm.indexes = map[uint32]*HandshakeHostInfo{} } - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), - "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
- Debug("Pending hostmap hostInfo deleted") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Pending hostmap hostInfo deleted", + "hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}, + ) } } @@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() { // Utility functions below -func generateIndex(l *logrus.Logger) (uint32, error) { +func generateIndex(l *slog.Logger) (uint32, error) { b := make([]byte, 4) // Let zero mean we don't know the ID, so don't generate zero @@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) { for index == 0 { _, err := rand.Read(b) if err != nil { - l.Errorln(err) + l.Error("Failed to generate index", "error", err) return 0, err } index = binary.BigEndian.Uint32(b) } - if l.Level >= logrus.DebugLevel { - l.WithField("index", index). - Debug("Generated index") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Generated index", "index", index) } return index, nil } diff --git a/hostmap.go b/hostmap.go index 25181d83..08acd1be 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,9 +1,11 @@ package nebula import ( + "context" "encoding/json" "errors" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -13,10 +15,10 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" ) const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address @@ -60,7 +62,7 @@ type HostMap struct { RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] - l *logrus.Logger + l *slog.Logger } // For synchronization, treat the pointed-to Relay struct as immutable. 
To edit the Relay @@ -313,7 +315,7 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { +func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap { hm := newHostMap(l) hm.reload(c, true) @@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { hm.reload(c, false) }) - l.WithField("preferredRanges", hm.GetPreferredRanges()). - Info("Main HostMap created") + l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges()) return hm } -func newHostMap(l *logrus.Logger) *HostMap { +func newHostMap(l *slog.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, @@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) { preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { - hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") + hm.l.Warn("Failed to parse preferred ranges, ignoring", + "error", err, + "range", rawPreferredRanges, + ) continue } @@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) { oldRanges := hm.preferredRanges.Swap(&preferredRanges) if !initial { - hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") + hm.l.Info("preferred_ranges changed", + "oldPreferredRanges", *oldRanges, + "newPreferredRanges", preferredRanges, + ) } } } @@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad hm.Indexes = map[uint32]*HostInfo{} } - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
- Debug("Hostmap hostInfo deleted") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Hostmap hostInfo deleted", + "hostMap", m{"mapTotalSize": len(hm.Hosts), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}, + ) } if isLastHostinfo { @@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). - Debug("Hostmap vpnIp added") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Hostmap vpnIp added", + "hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}, + ) } } @@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica } } -func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { +// logger returns a derived slog.Logger with per-hostinfo fields pre-bound. +func (i *HostInfo) logger(l *slog.Logger) *slog.Logger { if i == nil { - return logrus.NewEntry(l) + return l } - li := l.WithField("vpnAddrs", i.vpnAddrs). - WithField("localIndex", i.localIndexId). 
- WithField("remoteIndex", i.remoteIndexId) + li := l.With( + "vpnAddrs", i.vpnAddrs, + "localIndex", i.localIndexId, + "remoteIndex", i.remoteIndexId, + ) if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Certificate.Name()) + li = li.With("certName", peerCert.Certificate.Name()) } } @@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { +func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) - if l.Level >= logrus.TraceLevel { - l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName") + if l.Enabled(context.Background(), logging.LevelTrace) { + l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName", + "interfaceName", i.Name, + "allow", allow, + ) } if !allow { @@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { } if !addr.IsValid() { - if l.Level >= logrus.DebugLevel { - l.WithField("localAddr", rawAddr).Debug("addr was invalid") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("addr was invalid", "localAddr", rawAddr) } continue } @@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { isAllowed := allowList.Allow(addr) - if l.Level >= logrus.TraceLevel { - l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") + if l.Enabled(context.Background(), logging.LevelTrace) { + l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow", + "localAddr", addr, + "allowed", isAllowed, + ) } if 
!isAllowed { continue diff --git a/hostmap_test.go b/hostmap_test.go index e34a4ad0..2bd7bd43 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { func TestHostMap_reload(t *testing.T) { l := test.NewLogger() - c := config.NewC(l) + c := config.NewC(test.NewLogger()) hm := NewHostMapFromConfig(l, c) diff --git a/inside.go b/inside.go index 0d53f952..68cb38ec 100644 --- a/inside.go +++ b/inside.go @@ -1,9 +1,10 @@ package nebula import ( + "context" + "log/slog" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" @@ -14,8 +15,11 @@ import ( func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while validating outbound packet", + "packet", packet, + "error", err, + ) } return } @@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) if err != nil { - f.l.WithError(err).Error("Failed to forward to tun") + f.l.Error("Failed to forward to tun", "error", err) } } // Otherwise, drop. On linux, we should never see these packets - Linux @@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet if hostinfo == nil { f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddr", fwPacket.RemoteAddr). - WithField("fwPacket", fwPacket). 
- Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", + "vpnAddr", fwPacket.RemoteAddr, + "fwPacket", fwPacket, + ) } return } @@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } else { f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l). - WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping outbound packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping outbound packet", + "fwPacket", fwPacket, + "reason", dropReason, + ) } } } @@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { _, err := f.readers[q].Write(out) if err != nil { - f.l.WithError(err).Error("Failed to write to tun") + f.l.Error("Failed to write to tun", "error", err) } } @@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * } if len(out) > iputil.MaxRejectPacketSize { - if f.l.GetLevel() >= logrus.InfoLevel { - f.l. - WithField("packet", packet). - WithField("outPacket", out). - Info("rejectOutside: packet too big, not sending") + if f.l.Enabled(context.Background(), slog.LevelInfo) { + f.l.Info("rejectOutside: packet too big, not sending", + "packet", packet, + "outPacket", out, + ) } return } @@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac // This would also need to interact with unsafe_route updates through reloading the config or // use of the use_system_route_table option - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("destination", destinationAddr). - WithField("originalGateway", gatewayAddr). 
- Debugln("Calculated gateway for ECMP not available, attempting other gateways") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways", + "destination", destinationAddr, + "originalGateway", gatewayAddr, + ) } for i := range gateways { @@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { - f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) + f.l.Warn("error while parsing outgoing packet for firewall check", "error", err) return } // check if packet is in outbound fw rules dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("fwPacket", fp). - WithField("reason", dropReason). - Debugln("dropping cached packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping cached packet", + "fwPacket", fp, + "reason", dropReason, + ) } return } @@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message }) if hostInfo == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddr", vpnAddr). - Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes", + "vpnAddr", vpnAddr, + ) } return } @@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo, if noiseutil.EncryptLockNeeded { via.ConnectionState.writeLock.Unlock() } - via.logger(f.l). - WithField("outCap", cap(out)). - WithField("payloadLen", len(ad)). - WithField("headerLen", len(out)). - WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()). 
- Error("SendVia out buffer not large enough for relay") + via.logger(f.l).Error("SendVia out buffer not large enough for relay", + "outCap", cap(out), + "payloadLen", len(ad), + "headerLen", len(out), + "cipherOverhead", via.ConnectionState.eKey.Overhead(), + ) return } @@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo, via.ConnectionState.writeLock.Unlock() } if err != nil { - via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") + via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) return } err = f.writers[0].WriteTo(out, via.remote) if err != nil { - via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") + via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err) } f.connectionManager.RelayUsed(relay.LocalIndex) } @@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Lighthouse update triggered for punch due to rebind counter", + "vpnAddrs", hostinfo.vpnAddrs, + ) } } @@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType ci.writeLock.Unlock() } if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).WithField("counter", c). - WithField("attemptedCounter", c). 
- Error("Failed to encrypt outgoing packet") + hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", + "error", err, + "udpAddr", remote, + "counter", c, + "attemptedCounter", c, + ) return } if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") + hostinfo.logger(f.l).Error("Failed to write outgoing packet", + "error", err, + "udpAddr", remote, + ) } } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") + hostinfo.logger(f.l).Error("Failed to write outgoing packet", + "error", err, + "udpAddr", remote, + ) } } else { // Try to send via a relay @@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) - hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") + hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", + "relay", relayIP, + "error", err, + ) continue } f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) diff --git a/interface.go b/interface.go index 6d040884..5fedcdd3 100644 --- a/interface.go +++ b/interface.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "sync" "sync/atomic" @@ -12,7 +13,7 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -46,7 +47,7 @@ type InterfaceConfig struct { reQueryWait time.Duration ConntrackCacheTimeout time.Duration - l *logrus.Logger + l *slog.Logger } type 
Interface struct { @@ -100,7 +101,7 @@ type Interface struct { messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics - l *logrus.Logger + l *slog.Logger } type EncWriter interface { @@ -223,13 +224,16 @@ func (f *Interface) activate() error { addr, err := f.outside.LocalAddr() if err != nil { - f.l.WithError(err).Error("Failed to get udp listen address") + f.l.Error("Failed to get udp listen address", "error", err) } - f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). - WithField("build", f.version).WithField("udpAddr", addr). - WithField("boringcrypto", boringEnabled()). - Info("Nebula interface is active") + f.l.Info("Nebula interface is active", + "interface", f.inside.Name(), + "networks", f.myVpnNetworks, + "build", f.version, + "udpAddr", addr, + "boringcrypto", boringEnabled(), + ) if f.routines > 1 { if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() { @@ -305,7 +309,7 @@ func (f *Interface) listenOut(i int) { li = f.outside } - ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout) + ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() plaintext := make([]byte, udp.MTU) h := &header.H{} @@ -313,15 +317,15 @@ func (f *Interface) listenOut(i int) { nb := make([]byte, 12, 12) err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) }) if err != nil && !f.closed.Load() { - f.l.WithError(err).Error("Error while reading inbound packet, closing") + f.l.Error("Error while reading inbound packet, closing", "error", err) f.onFatal(err) } - f.l.Debugf("underlay reader %v is done", i) + f.l.Debug("underlay reader is done", "reader", i) } func (f 
*Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -330,22 +334,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout) + conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) for { n, err := reader.Read(packet) if err != nil { if !f.closed.Load() { - f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") + f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) f.onFatal(err) } break } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) + f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) } - f.l.Debugf("overlay reader %v is done", i) + f.l.Debug("overlay reader is done", "reader", i) } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { @@ -365,7 +369,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { if initial || c.HasChanged("pki.disconnect_invalid") { f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true)) if !initial { - f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load()) + f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load()) } } } @@ -379,7 +383,7 @@ func (f *Interface) reloadFirewall(c *config.C) { fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { - f.l.WithError(err).Error("Error while creating firewall during reload") + f.l.Error("Error while creating firewall during reload", "error", err) return } @@ -392,10 +396,11 @@ func (f *Interface) reloadFirewall(c *config.C) { // If rulesVersion is back to zero, we have wrapped all the way around. Be // safe and just reset conntrack in this case. if fw.rulesVersion == 0 { - f.l.WithField("firewallHashes", fw.GetRuleHashes()). 
- WithField("oldFirewallHashes", oldFw.GetRuleHashes()). - WithField("rulesVersion", fw.rulesVersion). - Warn("firewall rulesVersion has overflowed, resetting conntrack") + f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack", + "firewallHashes", fw.GetRuleHashes(), + "oldFirewallHashes", oldFw.GetRuleHashes(), + "rulesVersion", fw.rulesVersion, + ) } else { fw.Conntrack = conntrack } @@ -403,10 +408,11 @@ func (f *Interface) reloadFirewall(c *config.C) { f.firewall = fw oldFw.Destroy() - f.l.WithField("firewallHashes", fw.GetRuleHashes()). - WithField("oldFirewallHashes", oldFw.GetRuleHashes()). - WithField("rulesVersion", fw.rulesVersion). - Info("New firewall has been installed") + f.l.Info("New firewall has been installed", + "firewallHashes", fw.GetRuleHashes(), + "oldFirewallHashes", oldFw.GetRuleHashes(), + "rulesVersion", fw.rulesVersion, + ) } func (f *Interface) reloadSendRecvError(c *config.C) { @@ -428,8 +434,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) { } } - f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()). - Info("Loaded send_recv_error config") + f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String()) } } @@ -452,8 +457,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) { } } - f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()). 
- Info("Loaded accept_recv_error config") + f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String()) } } @@ -527,7 +531,7 @@ func (f *Interface) Close() error { for i, u := range f.writers { err := u.Close() if err != nil { - f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket") + f.l.Error("Error while closing udp socket", "error", err, "writer", i) errs = append(errs, err) } } diff --git a/lighthouse.go b/lighthouse.go index 50140e9e..6034e68c 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -15,10 +16,10 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -76,12 +77,12 @@ type LightHouse struct { metrics *MessageMetrics metricHolepunchTx metrics.Counter - l *logrus.Logger + l *slog.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -133,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, case *util.ContextualError: v.Log(l) case error: - l.WithError(err).Error("failed to reload lighthouse") + l.Error("failed to reload lighthouse", "error", err) } }) 
@@ -205,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used addr := addrs[0].Unmap() if lh.myVpnNetworksTable.Contains(addr) { - lh.l.WithField("addr", rawAddr).WithField("entry", i+1). - Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") + lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range", + "addr", rawAddr, + "entry", i+1, + ) continue } @@ -224,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10))) if !initial { - lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load()) + lh.l.Info("lighthouse.interval changed", + "interval", lh.interval.Load(), + ) if lh.updateCancel != nil { // May not always have a running routine @@ -336,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { for _, v := range c.GetStringSlice("relay.relays", nil) { configRIP, err := netip.ParseAddr(v) if err != nil { - lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") + lh.l.Warn("Parse relay from config failed", + "relay", v, + "error", err, + ) } else { - lh.l.WithField("relay", v).Info("Read relay from config") + lh.l.Info("Read relay from config", "relay", v) relaysForMe = append(relaysForMe, configRIP) } } @@ -363,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) { } if !lh.myVpnNetworksTable.Contains(addr) { - lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). 
- Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not") + lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not", + "vpnAddr", addr, + "networks", lh.myVpnNetworks, + ) } out[i] = addr } @@ -435,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc } if !lh.myVpnNetworksTable.Contains(vpnAddr) { - lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). - Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work") + lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work", + "vpnAddr", vpnAddr, + "networks", lh.myVpnNetworks, + "entry", i+1, + ) } vals, ok := v.([]any) @@ -537,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { lh.Lock() rm, ok := lh.addrMap[allVpnAddrs[0]] if ok { + debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug) for _, addr := range allVpnAddrs { srm := lh.addrMap[addr] if srm == rm { delete(lh.addrMap, addr) - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", addr) + if debugEnabled { + lh.l.Debug("deleting from lighthouse", "vpnAddr", addr) } } } @@ -659,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow). 
- Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddrs", vpnAddrs, + "udpAddr", to, + "allow", allow, + ) } if !allow { return false @@ -678,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { udpAddr := protoV4AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). - Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddr", vpnAddr, + "udpAddr", udpAddr, + "allow", allow, + ) } if !allow { @@ -698,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { udpAddr := protoV6AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). - Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddr", vpnAddr, + "udpAddr", udpAddr, + "allow", allow, + ) } if !allow { @@ -775,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { if v == cert.Version1 { if !addr.Is4() { - lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). 
- Error("Can't query lighthouse for v6 address using a v1 protocol") + lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol", + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } @@ -787,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { v1Query, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("queryVpnAddr", addr). - WithField("lighthouseAddr", lhVpnAddr). - Error("Failed to marshal lighthouse v1 query payload") + lh.l.Error("Failed to marshal lighthouse v1 query payload", + "error", err, + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -804,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { v2Query, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("queryVpnAddr", addr). - WithField("lighthouseAddr", lhVpnAddr). - Error("Failed to marshal lighthouse v2 query payload") + lh.l.Error("Failed to marshal lighthouse v2 query payload", + "error", err, + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -815,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { queried++ } else { - lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + lh.l.Debug("unsupported protocol version", + "op", "query", + "queryVpnAddr", addr, + "version", v, + ) continue } } @@ -907,8 +940,9 @@ func (lh *LightHouse) SendUpdate() { if v == cert.Version1 { if v1Update == nil { if !lh.myVpnNetworks[0].Addr().Is4() { - lh.l.WithField("lighthouseAddr", lhVpnAddr). 
- Warn("cannot update lighthouse using v1 protocol without an IPv4 address") + lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address", + "lighthouseAddr", lhVpnAddr, + ) continue } var relays []uint32 @@ -932,8 +966,10 @@ func (lh *LightHouse) SendUpdate() { v1Update, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). - Error("Error while marshaling for lighthouse v1 update") + lh.l.Error("Error while marshaling for lighthouse v1 update", + "error", err, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -959,8 +995,10 @@ func (lh *LightHouse) SendUpdate() { v2Update, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). - Error("Error while marshaling for lighthouse v2 update") + lh.l.Error("Error while marshaling for lighthouse v2 update", + "error", err, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -969,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() { updated++ } else { - lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) + lh.l.Debug("unsupported protocol version", + "op", "update", + "version", v, + ) continue } } @@ -983,7 +1024,7 @@ type LightHouseHandler struct { out []byte pb []byte meta *NebulaMeta - l *logrus.Logger + l *slog.Logger } func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { @@ -1032,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). - Error("Failed to unmarshal lighthouse packet") + lhh.l.Error("Failed to unmarshal lighthouse packet", + "error", err, + "vpnAddrs", fromVpnAddrs, + "udpAddr", rAddr, + ) return } if n.Details == nil { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). 
- Error("Invalid lighthouse update") + lhh.l.Error("Invalid lighthouse update", + "vpnAddrs", fromVpnAddrs, + "udpAddr", rAddr, + ) return } @@ -1067,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I don't answer queries, but received from: ", addr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("I don't answer queries, but received one", "from", addr) } return } queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details). - Debugln("Dropping malformed HostQuery") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Dropping malformed HostQuery", + "from", fromVpnAddrs, + "details", n.Details, + ) } return } if useVersion == cert.Version1 && queryVpnAddr.Is6() { // this case really shouldn't be possible to represent, but reject it anyway. - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr). 
- Debugln("invalid vpn addr for v1 handleHostQuery") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("invalid vpn addr for v1 handleHostQuery", + "vpnAddrs", fromVpnAddrs, + "queryVpnAddr", queryVpnAddr, + ) } return } @@ -1110,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") + lhh.l.Error("Failed to marshal lighthouse host query reply", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1138,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd if ok { whereToPunch = newDest } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("unable to punch to host, no addresses in common", + "to", crt.Networks(), + ) } } } @@ -1165,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") + lhh.l.Error("Failed to marshal lighthouse host was queried for", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1207,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) } } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("version", v).Debug("unsupported protocol version") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("unsupported protocol version", + "op", "coalesceAnswers", + "version", v, + ) } } } @@ -1221,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ certVpnAddr, _, err := 
n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Error("dropping malformed HostQueryReply", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) } return } @@ -1247,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs) } return } @@ -1271,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp //Simple check that the host sent this not someone else, if detailsVpnAddr is filled if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Host sent invalid update", + "vpnAddrs", fromVpnAddrs, + "answer", detailsVpnAddr, + ) } return } @@ -1294,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp switch useVersion { case cert.Version1: if !fromVpnAddrs[0].Is4() { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message", + "vpnAddrs", fromVpnAddrs, + ) return } vpnAddrB := fromVpnAddrs[0].As4() @@ -1302,13 +1371,16 @@ func (lhh 
*LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp case cert.Version2: // do nothing, we want to send a blank message default: - lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") + lhh.l.Error("invalid protocol version", "useVersion", useVersion) return } ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") + lhh.l.Error("Failed to marshal lighthouse host update ack", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1325,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("dropping invalid HostPunchNotification", + "details", n.Details, + "error", err, + ) } return } @@ -1343,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn lhh.lh.punchConn.WriteTo(empty, vpnPeer) }() - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Punching", + "vpnPeer", vpnPeer, + "logVpnAddr", logVpnAddr, + ) } } @@ -1369,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn if lhh.lh.punchy.GetRespond() { go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Sending a nebula test packet", + "vpnAddr", detailsVpnAddr, + ) } //NOTE: we have to allocate a new output buffer here since we are spawning a 
new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine diff --git a/logger.go b/logger.go deleted file mode 100644 index aaf6f29c..00000000 --- a/logger.go +++ /dev/null @@ -1,45 +0,0 @@ -package nebula - -import ( - "fmt" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" -) - -func configLogger(l *logrus.Logger, c *config.C) error { - // set up our logging level - logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) - if err != nil { - return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) - } - l.SetLevel(logLevel) - - disableTimestamp := c.GetBool("logging.disable_timestamp", false) - timestampFormat := c.GetString("logging.timestamp_format", "") - fullTimestamp := (timestampFormat != "") - if timestampFormat == "" { - timestampFormat = time.RFC3339 - } - - logFormat := strings.ToLower(c.GetString("logging.format", "text")) - switch logFormat { - case "text": - l.Formatter = &logrus.TextFormatter{ - TimestampFormat: timestampFormat, - FullTimestamp: fullTimestamp, - DisableTimestamp: disableTimestamp, - } - case "json": - l.Formatter = &logrus.JSONFormatter{ - TimestampFormat: timestampFormat, - DisableTimestamp: disableTimestamp, - } - default: - return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) - } - - return nil -} diff --git a/logging/logger.go b/logging/logger.go new file mode 100644 index 00000000..bbc10bb3 --- /dev/null +++ b/logging/logger.go @@ -0,0 +1,233 @@ +// Package logging wires the nebula runtime-reconfigurable slog handler used +// by nebula.Main and the nebula CLI binaries. Callers build a logger with +// NewLogger, then call ApplyConfig at startup and from a config reload +// callback to push logging.level, logging.format, and +// logging.disable_timestamp changes onto the logger without rebuilding it. 
+package logging
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"log/slog"
+	"strings"
+	"sync/atomic"
+	"time"
+)
+
+// Config is the subset of *config.C that ApplyConfig reads. Declaring it
+// here keeps the logging package from depending on config directly, which
+// would cycle through the shared test helpers (test.NewLogger imports
+// logging, and config's tests import test). *config.C satisfies this
+// interface structurally with no adapter.
+type Config interface {
+	GetString(key, def string) string
+	GetBool(key string, def bool) bool
+}
+
+// LevelTrace is a custom slog level below Debug, used when logging.level is
+// "trace". slog has no builtin trace level; the value is one step below
+// slog.LevelDebug in slog's 4-point spacing.
+const LevelTrace = slog.Level(-8)
+
+// NewLogger returns a *slog.Logger whose level, format, and timestamp
+// emission can be reconfigured at runtime via ApplyConfig and the SSH debug
+// commands. The default configuration is info-level text output so log
+// calls made before ApplyConfig runs still produce output. Timestamps
+// follow slog's default RFC3339 format with millisecond precision; set
+// logging.disable_timestamp in config to suppress them.
+//
+// ApplyConfig and the SSH commands discover the reconfig surface via
+// structural type-assertion on l.Handler(), so replacement implementations
+// (tests, platform-specific sinks) need only implement the subset of
+// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)}
+// they care about. Callers that pass a plain *slog.Logger without these
+// methods get a silent no-op; reconfiguration is always opt-in.
+func NewLogger(w io.Writer) *slog.Logger {
+	return slog.New(NewHandler(w))
+}
+
+// NewHandler builds the *Handler that NewLogger wraps.
Exported for +// platform-specific sinks (notably cmd/nebula-service/logs_windows.go) +// that want to wrap the handler with extra behavior, such as tagging each +// record with its Event Log severity, while still benefiting from all the +// level / format / timestamp / WithAttrs machinery implemented here. +func NewHandler(w io.Writer) *Handler { + root := &handlerRoot{} + root.level.Set(slog.LevelInfo) + opts := &slog.HandlerOptions{Level: &root.level} + return &Handler{ + root: root, + text: slog.NewTextHandler(w, opts), + json: slog.NewJSONHandler(w, opts), + } +} + +// handlerRoot carries the reconfiguration state shared by every logger +// derived from a NewHandler call. All fields are consulted on the log +// path and updated lock-free. +type handlerRoot struct { + level slog.LevelVar + disableTimestamp atomic.Bool + // jsonMode picks which of the pre-derived inner handlers Handler.Handle + // dispatches to. Flipping it propagates instantly to every derived logger + // without rebuilding or chain-replaying anything. + jsonMode atomic.Bool +} + +// Handler is the slog.Handler returned by NewHandler. It holds two +// pre-derived slog handlers -- one text, one json -- both built from the +// same accumulated WithAttrs/WithGroup state. Handle picks which one to +// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes +// effect immediately across the whole process without having to rebuild +// any derived loggers. 
+type Handler struct { + root *handlerRoot + text slog.Handler + json slog.Handler +} + +func (h *Handler) Enabled(_ context.Context, l slog.Level) bool { + return h.root.level.Level() <= l +} + +func (h *Handler) Handle(ctx context.Context, r slog.Record) error { + if h.root.disableTimestamp.Load() { + r.Time = time.Time{} + } + if h.root.jsonMode.Load() { + return h.json.Handle(ctx, r) + } + return h.text.Handle(ctx, r) +} + +func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return h + } + return &Handler{ + root: h.root, + text: h.text.WithAttrs(attrs), + json: h.json.WithAttrs(attrs), + } +} + +func (h *Handler) WithGroup(name string) slog.Handler { + if name == "" { + return h + } + return &Handler{ + root: h.root, + text: h.text.WithGroup(name), + json: h.json.WithGroup(name), + } +} + +// SetLevel updates the effective log level. Propagates to every derived +// logger via the shared LevelVar. +func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) } + +// GetLevel reports the current log level. +func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() } + +// SetFormat flips the output format atomically. Valid formats are "text" +// and "json". Every derived logger sees the new format on its next Handle +// call; no rebuild or registration is required. +func (h *Handler) SetFormat(format string) error { + switch format { + case "text": + h.root.jsonMode.Store(false) + case "json": + h.root.jsonMode.Store(true) + default: + return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"}) + } + return nil +} + +// GetFormat reports the currently selected format name. +func (h *Handler) GetFormat() string { + if h.root.jsonMode.Load() { + return "json" + } + return "text" +} + +// SetDisableTimestamp toggles whether Handle zeroes r.Time before +// dispatching (slog's builtin text/json handlers skip emitting the time +// attribute on a zero time). 
+func (h *Handler) SetDisableTimestamp(v bool) { h.root.disableTimestamp.Store(v) } + +// ApplyConfig reads logging.level, logging.format, and (optionally) +// logging.disable_timestamp from c and applies them to l. The reconfig +// surface is discovered via structural type-assertion on l.Handler(), so +// foreign handlers silently opt out of whichever capabilities they do not +// implement. +// +// nebula.Main does NOT call this function on your behalf; callers that want +// config-driven log level / format / timestamp updates invoke it at +// startup and register it as a reload callback themselves. This keeps the +// library from mutating an embedder's logger without their say-so. +func ApplyConfig(l *slog.Logger, c Config) error { + h := l.Handler() + + lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) + if err != nil { + return err + } + if ls, ok := h.(interface{ SetLevel(slog.Level) }); ok { + ls.SetLevel(lvl) + } + + format := strings.ToLower(c.GetString("logging.format", "text")) + if fs, ok := h.(interface{ SetFormat(string) error }); ok { + if err := fs.SetFormat(format); err != nil { + return err + } + } + + if ts, ok := h.(interface{ SetDisableTimestamp(bool) }); ok { + ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false)) + } + return nil +} + +// ParseLevel converts a config-string level name ("trace", "debug", "info", +// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and +// "panic" are accepted for backwards compatibility with pre-slog configs +// and both map to slog.LevelError. 
+func ParseLevel(s string) (slog.Level, error) { + switch s { + case "trace": + return LevelTrace, nil + case "debug": + return slog.LevelDebug, nil + case "info": + return slog.LevelInfo, nil + case "warn", "warning": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + case "fatal", "panic": + return slog.LevelError, nil + default: + return 0, fmt.Errorf("not a valid logging level: %q", s) + } +} + +// LevelName returns a human-readable name for a slog.Level matching the +// strings accepted by ParseLevel. +func LevelName(l slog.Level) string { + switch { + case l <= LevelTrace: + return "trace" + case l <= slog.LevelDebug: + return "debug" + case l <= slog.LevelInfo: + return "info" + case l <= slog.LevelWarn: + return "warn" + default: + return "error" + } +} diff --git a/logging/logger_bench_test.go b/logging/logger_bench_test.go new file mode 100644 index 00000000..eb29c1c3 --- /dev/null +++ b/logging/logger_bench_test.go @@ -0,0 +1,90 @@ +package logging + +import ( + "context" + "io" + "log/slog" + "testing" +) + +// BenchmarkLogger_* compare the handler returned by NewLogger against a +// stock slog text handler. The key thing we care about is the per-log +// cost on a logger that has been derived via .With(), because that is the +// shape subsystems store on their structs (HostInfo.logger(), +// lh.l.With("subsystem", ...), etc.) and call from hot paths. 
+ +func BenchmarkLogger_Stock_RootInfo(b *testing.B) { + l := slog.New(slog.DiscardHandler) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Nebula_RootInfo(b *testing.B) { + l := NewLogger(io.Discard) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) { + l := slog.New(slog.DiscardHandler).With( + "subsystem", "bench", + "localIndex", 1234, + ) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) { + l := NewLogger(io.Discard).With( + "subsystem", "bench", + "localIndex", 1234, + ) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +// Gated-off-path benchmarks: mimic the typical hot-path shape +// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below +// the active level. This is the dominant pattern in inside.go/outside.go and +// what we pay on every packet. 
+func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) { + l := slog.New(slog.DiscardHandler).With( + "subsystem", "bench", + "localIndex", 1234, + ) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if l.Enabled(ctx, slog.LevelDebug) { + l.Debug("hello", "i", i) + } + } +} + +func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) { + l := NewLogger(io.Discard).With( + "subsystem", "bench", + "localIndex", 1234, + ) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if l.Enabled(ctx, slog.LevelDebug) { + l.Debug("hello", "i", i) + } + } +} diff --git a/main.go b/main.go index 0ac63dfa..f692f317 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,13 @@ package nebula import ( "context" "fmt" + "log/slog" "net" "net/netip" "runtime/debug" "strings" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/sshd" @@ -20,7 +20,7 @@ import ( type m = map[string]any -func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { +func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. 
defer func() { @@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg buildVersion = moduleVersion() } - l := logger - l.Formatter = &logrus.TextFormatter{ - FullTimestamp: true, - } - // Print the config if in test, the exit comes later if configTest { b, err := yaml.Marshal(c.Settings) @@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } // Print the final config - l.Println(string(b)) + l.Info(string(b)) } - err := configLogger(l, c) - if err != nil { - return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) - } - - c.RegisterReloadCallback(func(c *config.C) { - err := configLogger(l, c) - if err != nil { - l.WithError(err).Error("Failed to configure the logger") - } - }) - pki, err := NewPKIFromConfig(l, c) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) @@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } - l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") + l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes()) - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) + ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) } @@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") + l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err) sshStart = nil } } @@ -99,7 +82,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg routines = 1 } 
if routines > 1 { - l.WithField("routines", routines).Info("Using multiple routines") + l.Info("Using multiple routines", "routines", routines) } } else { // deprecated and undocumented @@ -107,7 +90,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg udpQueues := c.GetInt("listen.routines", 1) routines = max(tunQueues, udpQueues) if routines != 1 { - l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead") + l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines) } } @@ -120,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg conntrackCacheTimeout = 1 * time.Second } if conntrackCacheTimeout > 0 { - l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") + l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout) } var tun overlay.Device @@ -166,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } for i := 0; i < routines; i++ { - l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port))) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) @@ -217,7 +200,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) if err != nil { - l.WithError(err).Warn("Failed to start DNS responder") + l.Warn("Failed to start DNS responder", "error", err) } ifConfig := &InterfaceConfig{ diff --git a/outside.go b/outside.go index eba9d887..1e00a0a9 100644 --- a/outside.go +++ b/outside.go @@ -1,15 +1,16 @@ package nebula import ( + "context" "encoding/binary" 
"errors" + "log/slog" "net/netip" "time" "github.com/google/gopacket/layers" "golang.org/x/net/ipv6" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "golang.org/x/net/ipv4" @@ -24,7 +25,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) + f.l.Info("Error while parsing inbound packet", + "from", via, + "error", err, + "packet", packet, + ) } return } @@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, //l.Error("in packet ", header, packet[HeaderLen:]) if !via.IsRelayed { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Refusing to process double encrypted packet", "from", via) } return } @@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. 
- hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + hostinfo.logger(f.l).Error("HostInfo missing remote relay index", + "vpnAddrs", hostinfo.vpnAddrs, + "remoteIndex", h.RemoteIndex, + ) return } @@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, // Find the target HostInfo relay object targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") + hostinfo.logger(f.l).Info("Failed to find target host info by ip", + "relayTo", relay.PeerAddr, + "error", err, + "hostinfo.vpnAddrs", hostinfo.vpnAddrs, + ) return } @@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + hostinfo.logger(f.l).Info("Unexpected target relay state", + "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], + "targetRelayState", targetRelay.State, + ) return } } @@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). 
- Error("Failed to decrypt lighthouse packet") + hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt test packet") + hostinfo.logger(f.l).Error("Failed to decrypt test packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -192,14 +212,15 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt CloseTunnel packet") + hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", + "error", err, + "from", via, + "packet", packet, + ) return } - hostinfo.logger(f.l).WithField("from", via). - Info("Close tunnel received, tearing down.") + hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via) f.closeTunnel(hostinfo) return @@ -211,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). 
- Error("Failed to decrypt Control packet") + hostinfo.logger(f.l).Error("Failed to decrypt Control packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -221,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via) + } return } @@ -247,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { if !via.IsRelayed && hostinfo.remote != via.UdpAddr { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") - return - } - - if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). - Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). 
- Info("Host roamed to new udp ip/port.") + if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote", + "suppressSeconds", RoamingSuppressSeconds, + "udpAddr", hostinfo.remote, + "newAddr", via.UdpAddr, + ) + } + return + } + + hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.", + "udpAddr", hostinfo.remote, + "newAddr", via.UdpAddr, + ) hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote hostinfo.SetRemote(via.UdpAddr) @@ -491,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - hostinfo.logger(f.l).WithField("header", h). - Debugln("dropping out of window packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h) + } return nil, errors.New("out of window packet") } @@ -504,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") + hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) return false } err = newPacket(out, true, fwPacket) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("packet", out). - Warnf("Error while validating inbound packet") + hostinfo.logger(f.l).Warn("Error while validating inbound packet", + "error", err, + "packet", out, + ) return false } if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). 
- Debugln("dropping out of window packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) + } return false } @@ -526,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping inbound packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping inbound packet", + "fwPacket", fwPacket, + "reason", dropReason, + ) } return false } @@ -537,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { - f.l.WithError(err).Error("Failed to write to tun") + f.l.Error("Failed to write to tun", "error", err) } return true } @@ -553,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) _ = f.outside.WriteTo(b, endpoint) - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("index", index). - WithField("udpAddr", endpoint). - Debug("Recv error sent") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Recv error sent", + "index", index, + "udpAddr", endpoint, + ) } } func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if !f.acceptRecvErrorConfig.ShouldRecvError(addr) { - f.l.WithField("index", h.RemoteIndex). - WithField("udpAddr", addr). 
- Debug("Recv error received, ignoring") + f.l.Debug("Recv error received, ignoring", + "index", h.RemoteIndex, + "udpAddr", addr, + ) return } - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("index", h.RemoteIndex). - WithField("udpAddr", addr). - Debug("Recv error received") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Recv error received", + "index", h.RemoteIndex, + "udpAddr", addr, + ) } hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) if hostinfo == nil { - f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") + f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex) return } if hostinfo.remote.IsValid() && hostinfo.remote != addr { - f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) + f.l.Info("Someone spoofing recv_errors?", + "addr", addr, + "hostinfoRemote", hostinfo.remote, + ) return } diff --git a/test/tun.go b/overlay/overlaytest/noop.go similarity index 68% rename from test/tun.go rename to overlay/overlaytest/noop.go index fb32782f..956da7dd 100644 --- a/test/tun.go +++ b/overlay/overlaytest/noop.go @@ -1,4 +1,6 @@ -package test +// Package overlaytest provides fakes of overlay.Device for tests that do +// not want to touch a real tun device or route table. +package overlaytest import ( "errors" @@ -8,6 +10,9 @@ import ( "github.com/slackhq/nebula/routing" ) +// NoopTun is an overlay.Device that silently discards every read and write. +// Useful in tests that need to construct a nebula Interface but do not +// exercise the datapath. 
type NoopTun struct{} func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { diff --git a/overlay/route.go b/overlay/route.go index 61989581..c6403f91 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -2,6 +2,7 @@ package overlay import ( "fmt" + "log/slog" "math" "net" "net/netip" @@ -9,7 +10,6 @@ import ( "strconv" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) @@ -48,11 +48,14 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { +func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { - l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) + l.Warn("route MTU is not supported on this platform", + "goos", runtime.GOOS, + "route", r, + ) } gateways := r.Via diff --git a/overlay/route_test.go b/overlay/route_test.go index 9a959a55..f9d9dcd9 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) { routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 2) - routeTree, err := makeRouteTree(l, routes, true) + routeTree, err := makeRouteTree(test.NewLogger(), routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("1.0.0.2") @@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) { routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 3) - routeTree, err := makeRouteTree(l, routes, true) + routeTree, err := makeRouteTree(test.NewLogger(), routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("192.168.86.1") diff --git a/overlay/tun.go b/overlay/tun.go index e0bf69f6..3af1e189 100644 --- a/overlay/tun.go 
+++ b/overlay/tun.go @@ -2,10 +2,10 @@ package overlay import ( "fmt" + "log/slog" "net" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" ) @@ -22,9 +22,9 @@ func (e *NameError) Error() string { } // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) @@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return newTunFromFd(c, l, *fd, vpnNetworks) } } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..9cbb64be 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -6,12 +6,12 @@ package overlay import ( "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -23,10 +23,10 @@ type tun struct { vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, 
vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..524ef0cd 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" @@ -14,7 +15,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -30,7 +30,7 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr - l *logrus.Logger + l *slog.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte @@ -79,7 +79,7 @@ type ifreqAlias6 struct { Lifetime addrLifetime } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) 
(*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error { err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { - t.l.WithField("route", r.Cidr). - Warnf("unable to add unsafe_route, identical route already exists") + t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr) } else { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { @@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error { } } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.linkAddr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3dddaf..f47880dd 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -1,13 +1,14 @@ package overlay import ( + "context" "fmt" "io" + "log/slog" "net/netip" "strings" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/routing" ) @@ -19,10 +20,10 @@ type disabledTun struct { // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter rx metrics.Counter - l *logrus.Logger + l *slog.Logger } -func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { tun := &disabledTun{ vpnNetworks: vpnNetworks, read: make(chan []byte, queueLen), @@ 
-67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) { } t.tx.Inc(1) - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Write payload", "raw", prettyPacket(r)) } return copy(b, r), nil @@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { select { case t.read <- out: default: - t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response") + t.l.Debug("tun_disabled: dropped ICMP Echo Reply response") } return true @@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) { // Check for ICMP Echo Request before spending time doing the full parsing if t.handleICMPEchoRequest(b) { - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b)) } - } else if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") + } else if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b)) } return len(b), nil } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 91c51159..3d995553 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "io/fs" + "log/slog" "net/netip" "os" "sync/atomic" @@ -17,8 +18,9 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -93,7 +95,7 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr - l *logrus.Logger + l *slog.Logger fd 
int shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls @@ -243,7 +245,7 @@ func (t *tun) Close() error { if t.fd >= 0 { if err := unix.Close(t.fd); err != nil { - t.l.WithError(err).Error("Error closing device") + t.l.Error("Error closing device", "error", err) } t.fd = -1 } @@ -264,7 +266,7 @@ func (t *tun) Close() error { err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) } if err != nil { - t.l.WithError(err).Error("Error destroying tunnel") + t.l.Error("Error destroying tunnel", "error", err) } }() @@ -277,11 +279,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error @@ -584,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -599,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.linkAddr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..6bfcbdfb 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "sync" @@ -14,7 +15,6 @@ import ( "syscall" "github.com/gaissmai/bart" - 
"github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -25,14 +25,14 @@ type tun struct { vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 2830ff6b..c6cfb686 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "io" + "log/slog" "net" "net/netip" "os" @@ -17,7 +18,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -213,7 +213,7 @@ type tun struct { routesFromSystem map[netip.Prefix]routing.Gateways routesFromSystemLock sync.Mutex - l *logrus.Logger + l *slog.Logger } func (t *tun) Networks() []netip.Prefix { @@ -238,7 +238,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) if err != nil { return nil, err @@ -249,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, 
vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -299,7 +299,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. -func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { tfd, err := newTunFd(fd) if err != nil { _ = unix.Close(fd) @@ -378,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error { if !initial { if oldMaxMTU != newMaxMTU { t.setMTU() - t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU) + t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU) } if oldDefaultMTU != newDefaultMTU { for i := range t.vpnNetworks { err := t.setDefaultRoute(t.vpnNetworks[i]) if err != nil { - t.l.Warn(err) + t.l.Warn(err.Error()) } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU) } } } @@ -492,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error { } err = netlink.AddrDel(link, &al[i]) if err != nil { - t.l.WithError(err).Error("failed to remove address from tun address list") + t.l.Error("failed to remove address from tun address list", "error", err) } else { - t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + t.l.Info("removed address not listed in cert(s)", "removed", al[i].String()) } } @@ -538,12 +538,12 @@ func (t *tun) Activate() error { ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, 
uintptr(unsafe.Pointer(&ifrq))); err != nil { // If we can't set the queue length nebula will still work but it may lead to packet loss - t.l.WithError(err).Error("Failed to set tun tx queue length") + t.l.Error("Failed to set tun tx queue length", "error", err) } const modeNone = 1 if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil { - t.l.WithError(err).Warn("Failed to disable link local address generation") + t.l.Warn("Failed to disable link local address generation", "error", err) } if err = t.addIPs(link); err != nil { @@ -582,7 +582,7 @@ func (t *tun) setMTU() { ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)} if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { // This is currently a non fatal condition because the route table must have the MTU set appropriately as well - t.l.WithError(err).Error("Failed to set tun mtu") + t.l.Error("Failed to set tun mtu", "error", err) } } @@ -605,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { } err := netlink.RouteReplace(&nr) if err != nil { - t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") + t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr) //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` for i := 0; i < 2; i++ { time.Sleep(100 * time.Millisecond) @@ -613,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { if err == nil { break } else { - t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") + t.l.Warn("Failed to set default route MTU, retrying", + "error", err, + "cidr", cidr, + "mtu", t.DefaultMTU, + ) } } if err != nil { @@ -658,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added 
route", "route", r) } } @@ -690,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) { err := netlink.RouteDel(&nr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } } @@ -721,11 +725,11 @@ func (t *tun) watchRoutes() { netlinkOptions := netlink.RouteSubscribeOptions{ ReceiveBufferSize: t.useSystemRoutesBufferSize, ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0, - ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") }, + ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) }, } if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil { - t.l.WithError(err).Errorf("failed to subscribe to system route changes") + t.l.Error("failed to subscribe to system route changes", "error", err) return } @@ -767,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { link, err := netlink.LinkByName(t.Device) if err != nil { - t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") + t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device) return gateways } @@ -779,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { gateways = append(gateways, routing.NewGateway(gwAddr, 1)) } else { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") + t.l.Debug("Ignoring route update, gateway is not in our network", "route", r) } } else { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") + t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r) } } @@ -795,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways 
{ gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) } else { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") + t.l.Debug("Ignoring route update, gateway is not in our network", "route", r) } } else { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") + t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r) } } } @@ -830,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { gateways := t.getGatewaysFromRoute(&r.Route) if len(gateways) == 0 { // No gateways relevant to our network, no routing changes required. - t.l.WithField("route", r).Debug("Ignoring route update, no gateways") + t.l.Debug("Ignoring route update, no gateways", "route", r) return } if r.Dst == nil { - t.l.WithField("route", r).Debug("Ignoring route update, no destination address") + t.l.Debug("Ignoring route update, no destination address", "route", r) return } dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) if !ok { - t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + t.l.Debug("Ignoring route update, invalid destination address", "route", r) return } @@ -852,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { t.routesFromSystemLock.Lock() if r.Type == unix.RTM_NEWROUTE { - t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") + t.l.Info("Adding route", "destination", dst, "via", gateways) t.routesFromSystem[dst] = gateways newTree.Insert(dst, gateways) } else { - t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") + t.l.Info("Removing route", "destination", dst, "via", gateways) delete(t.routesFromSystem, dst) newTree.Delete(dst) } @@ -888,18 +892,18 @@ func (t *tun) Close() error { } err := t.readers[i].Close() if err != nil { - t.l.WithField("reader", i).WithError(err).Error("error closing tun reader") + 
t.l.Error("error closing tun reader", "reader", i, "error", err) } else { - t.l.WithField("reader", i).Info("closed tun reader") + t.l.Info("closed tun reader", "reader", i) } } //this is t.readers[0] too err := t.tunFile.Close() if err != nil { - t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader") + t.l.Error("error closing tun reader", "reader", 0, "error", err) } else { - t.l.WithField("reader", 0).Info("closed tun reader") + t.l.Info("closed tun reader", "reader", 0) } return err } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..c971bb6e 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "regexp" @@ -15,7 +16,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -63,18 +63,18 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger f *os.File fd int } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") @@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( err = unix.SetNonblock(fd, true) if err != nil { - l.WithError(err).Warn("Failed to set the tun device as nonblocking") + l.Warn("Failed to set the tun device as nonblocking", "error", err) } t := 
&tun{ @@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..81362184 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "regexp" @@ -15,7 +16,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -54,7 +54,7 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger f *os.File fd int // cache out buffer since we need to prepend 4 bytes for tun metadata @@ -63,11 +63,11 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") @@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( err = unix.SetNonblock(fd, true) if 
err != nil { - l.WithError(err).Warn("Failed to set the tun device as nonblocking") + l.Warn("Failed to set the tun device as nonblocking", "error", err) } t := &tun{ @@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..b2c2a0ea 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -4,14 +4,15 @@ package overlay import ( + "context" "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) @@ -21,14 +22,14 @@ type TestTun struct { vpnNetworks []netip.Prefix Routes []Route routeTree *bart.Table[routing.Gateways] - l *logrus.Logger + l *slog.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err @@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, 
error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) { return } - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Tun receiving injected packet", "dataLen", len(packet)) } t.rxPackets <- packet } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 223eabee..680dddb3 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -7,6 +7,7 @@ import ( "crypto" "fmt" "io" + "log/slog" "net/netip" "os" "path/filepath" @@ -16,7 +17,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -33,16 +33,16 @@ type winTun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger tun *wintun.NativeTun } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { err := checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) @@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( if err != nil { // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. // Trying a second time resolves the issue. 
- l.WithError(err).Debug("Failed to create wintun device, retrying") + l.Debug("Failed to create wintun device, retrying", "error", err) tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { return nil, &NameError{ @@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } if !foundDefault4 { @@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error { // See comment on luid.AddRoute err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..e5f27f37 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -2,14 +2,14 @@ package overlay import ( "io" + "log/slog" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return NewUserDevice(vpnNetworks) } diff --git a/pki.go b/pki.go index 0639fd3d..fb8cc5c6 100644 --- a/pki.go +++ b/pki.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/netip" "os" @@ -15,7 +16,6 @@ import ( "time" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" @@ -24,7 +24,7 @@ import ( type PKI struct { cs atomic.Pointer[CertState] caPool atomic.Pointer[cert.CAPool] - l *logrus.Logger + l *slog.Logger } type CertState struct { @@ -46,7 
+46,7 @@ type CertState struct { myVpnBroadcastAddrsTable *bart.Lite } -func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { +func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) { pki := &PKI{l: l} err := pki.reload(c, true) if err != nil { @@ -182,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { p.cs.Store(newState) if initial { - p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") + p.l.Debug("Client nebula certificate(s)", "cert", newState) } else { - p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") + p.l.Info("Client certificate(s) refreshed from disk", "cert", newState) } return nil } @@ -196,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { } p.caPool.Store(caPool) - p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") + p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints()) return nil } @@ -487,7 +487,7 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) { return c, b, nil } -func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { +func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) { caPathOrPEM := c.GetString("pki.ca", "") if caPathOrPEM == "" { return nil, errors.New("no pki.ca path or PEM data provided") @@ -512,7 +512,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { for _, crt := range caPool.CAs { if crt.Certificate.Expired(time.Now()) { expired++ - l.WithField("cert", crt).Warn("expired certificate present in CA pool") + l.Warn("expired certificate present in CA pool", "cert", crt) } } @@ -530,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { caPool.BlocklistFingerprint(fp) } - l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates") + l.Info("Blocklisted certificates", "fingerprintCount", 
len(bl)) } return caPool, nil diff --git a/pki_hup_benchmark_test.go b/pki_hup_benchmark_test.go index 39f648ff..bca23d78 100644 --- a/pki_hup_benchmark_test.go +++ b/pki_hup_benchmark_test.go @@ -41,7 +41,7 @@ func BenchmarkReloadConfigWithCAs(b *testing.B) { c := config.NewC(l) require.NoError(b, c.Load(dir)) - _, err := NewPKIFromConfig(l, c) + _, err := NewPKIFromConfig(test.NewLogger(), c) require.NoError(b, err) b.ReportAllocs() diff --git a/punchy.go b/punchy.go index 2034405a..6ecf4f85 100644 --- a/punchy.go +++ b/punchy.go @@ -1,10 +1,10 @@ package nebula import ( + "log/slog" "sync/atomic" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) @@ -14,10 +14,10 @@ type Punchy struct { delay atomic.Int64 respondDelay atomic.Int64 punchEverything atomic.Bool - l *logrus.Logger + l *slog.Logger } -func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { +func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { p := &Punchy{l: l} p.reload(c, true) @@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) { p.respond.Store(yes) if !initial { - p.l.Infof("punchy.respond changed to %v", p.GetRespond()) + p.l.Info("punchy.respond changed", "respond", p.GetRespond()) } } @@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) { if initial || c.HasChanged("punchy.delay") { p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) if !initial { - p.l.Infof("punchy.delay changed to %s", p.GetDelay()) + p.l.Info("punchy.delay changed", "delay", p.GetDelay()) } } if initial || c.HasChanged("punchy.target_all_remotes") { p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) if !initial { - p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") + p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything()) } } if initial || c.HasChanged("punchy.respond_delay") { 
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) if !initial { - p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay()) + p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay()) } } } diff --git a/punchy_test.go b/punchy_test.go index 56dd1c25..cbf9b17b 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -1,6 +1,8 @@ package nebula import ( + "context" + "log/slog" "testing" "time" @@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) { c := config.NewC(l) // Test defaults - p := NewPunchyFromConfig(l, c) + p := NewPunchyFromConfig(test.NewLogger(), c) assert.False(t, p.GetPunch()) assert.False(t, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) @@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) { // punchy deprecation c.Settings["punchy"] = true - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetPunch()) // punchy.punch c.Settings["punchy"] = map[string]any{"punch": true} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetRespond()) // punchy.respond c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetRespond()) // punchy.delay c.Settings["punchy"] = map[string]any{"delay": "1m"} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, time.Minute, p.GetDelay()) // punchy.respond_delay c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, time.Minute, p.GetRespondDelay()) } @@ -62,7 +64,7 @@ punchy: delay: 1m respond: false `)) - p := 
NewPunchyFromConfig(l, c) + p := NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, delay, p.GetDelay()) assert.False(t, p.GetRespond()) @@ -76,3 +78,158 @@ punchy: assert.Equal(t, newDelay, p.GetDelay()) assert.True(t, p.GetRespond()) } + +// The tests below pin the shape of each log line Punchy produces so changes +// cannot silently break whatever operators are grepping for. The assertions +// are on the structured message + attrs (e.g. "punchy.respond changed" with +// a respond=true field) rather than a formatted string. +// +// Punchy.reload also emits a spurious "Changing punchy.punch with reload is +// not supported" warning whenever any key under punchy changes, because of +// the c.HasChanged("punchy") fallback kept for the deprecated top-level +// punchy form. The tests filter by message rather than asserting total +// entry counts so that warning is tolerated without being locked into +// the format. + +type capturedEntry struct { + Level slog.Level + Msg string + Attrs map[string]any +} + +// capturingHandler is a slog.Handler that records each Record it receives so +// tests can assert on the level, message, and attribute map of individual log +// lines without coupling to any specific text format. 
+type capturingHandler struct { + entries []capturedEntry +} + +func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } + +func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error { + e := capturedEntry{ + Level: r.Level, + Msg: r.Message, + Attrs: make(map[string]any), + } + r.Attrs(func(a slog.Attr) bool { + e.Attrs[a.Key] = a.Value.Resolve().Any() + return true + }) + h.entries = append(h.entries, e) + return nil +} + +func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h } + +func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) { + t.Helper() + hook := &capturingHandler{} + return slog.New(hook), hook +} + +func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry { + t.Helper() + for _, e := range entries { + if e.Msg == msg { + return e + } + } + t.Fatalf("no entry with message %q among %d entries", msg, len(entries)) + return capturedEntry{} +} + +func TestPunchy_LogFormat_InitialEnabled(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: true}`)) + + NewPunchyFromConfig(l, c) + + entry := findEntry(t, hook.entries, "punchy enabled") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_InitialDisabled(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: false}`)) + + NewPunchyFromConfig(l, c) + + entry := findEntry(t, hook.entries, "punchy disabled") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: 
false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`)) + + entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.") + assert.Equal(t, slog.LevelWarn, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadRespond(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {respond: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`)) + + entry := findEntry(t, hook.entries, "punchy.respond changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"respond": true}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadDelay(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {delay: 1s}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`)) + + entry := findEntry(t, hook.entries, "punchy.delay changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"delay": 10 * time.Second}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`)) + + entry := findEntry(t, hook.entries, "punchy.target_all_remotes changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"target_all_remotes": true}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := 
config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`)) + + entry := findEntry(t, hook.entries, "punchy.respond_delay changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, entry.Attrs) +} diff --git a/relay_manager.go b/relay_manager.go index 91640f24..919bb2b6 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -5,22 +5,22 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net/netip" "sync/atomic" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) type relayManager struct { - l *logrus.Logger + l *slog.Logger hostmap *HostMap amRelay atomic.Bool } -func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager { +func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager { rm := &relayManager{ l: l, hostmap: hostmap, @@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c c.RegisterReloadCallback(func(c *config.C) { err := rm.reload(c, false) if err != nil { - l.WithError(err).Error("Failed to reload relay_manager") + rm.l.Error("Failed to reload relay_manager", "error", err) } }) return rm @@ -52,7 +52,7 @@ func (rm *relayManager) setAmRelay(v bool) { // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. 
-func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() for range 32 { @@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - fields := logrus.Fields{ - "relay": relayHostInfo.vpnAddrs[0], - "initiatorRelayIndex": m.InitiatorRelayIndex, - } - + var relayFrom, relayTo any if m.RelayFromAddr == nil { - fields["relayFrom"] = m.OldRelayFromAddr + relayFrom = m.OldRelayFromAddr } else { - fields["relayFrom"] = m.RelayFromAddr + relayFrom = m.RelayFromAddr } - if m.RelayToAddr == nil { - fields["relayTo"] = m.OldRelayToAddr + relayTo = m.OldRelayToAddr } else { - fields["relayTo"] = m.RelayToAddr + relayTo = m.RelayToAddr } - rm.l.WithFields(fields).Info("relayManager failed to update relay") + rm.l.Info("relayManager failed to update relay", + "relay", relayHostInfo.vpnAddrs[0], + "initiatorRelayIndex", m.InitiatorRelayIndex, + "relayFrom", relayFrom, + "relayTo", relayTo, + ) return nil, fmt.Errorf("unknown relay") } @@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { msg := &NebulaControl{} err := msg.Unmarshal(d) if err != nil { - h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + h.logger(f.l).Error("Failed to unmarshal control message", "error", err) return } @@ -147,20 +147,20 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { } func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) 
{ - rm.l.WithFields(logrus.Fields{ - "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), - "relayTo": protoAddrToNetAddr(m.RelayToAddr), - "initiatorRelayIndex": m.InitiatorRelayIndex, - "responderRelayIndex": m.ResponderRelayIndex, - "vpnAddrs": h.vpnAddrs}). - Info("handleCreateRelayResponse") + rm.l.Info("handleCreateRelayResponse", + "relayFrom", protoAddrToNetAddr(m.RelayFromAddr), + "relayTo", protoAddrToNetAddr(m.RelayToAddr), + "initiatorRelayIndex", m.InitiatorRelayIndex, + "responderRelayIndex", m.ResponderRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) target := m.RelayToAddr targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { - rm.l.WithError(err).Error("Failed to update relay for relayTo") + rm.l.Error("Failed to update relay for relayTo", "error", err) return } // Do I need to complete the relays now? @@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f // I'm the middle man. Let the initiator know that the I've established the relay they requested. peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") + rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr) return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") + rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0]) return } switch peerRelay.State { @@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f if v == cert.Version1 { peer := peerHostInfo.vpnAddrs[0] if !peer.Is4() { - rm.l.WithField("relayFrom", peer). - WithField("relayTo", target). - WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). 
- WithField("responderRelayIndex", resp.ResponderRelayIndex). - WithField("vpnAddrs", peerHostInfo.vpnAddrs). - Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") + rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address", + "relayFrom", peer, + "relayTo", target, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", peerHostInfo.vpnAddrs, + ) return } @@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f msg, err := resp.Marshal() if err != nil { - rm.l.WithError(err). - Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": resp.RelayFromAddr, - "relayTo": resp.RelayToAddr, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnAddrs": peerHostInfo.vpnAddrs}). 
- Info("send CreateRelayResponse") + rm.l.Info("send CreateRelayResponse", + "relayFrom", resp.RelayFromAddr, + "relayTo", resp.RelayToAddr, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", peerHostInfo.vpnAddrs, + ) } } } @@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f from := protoAddrToNetAddr(m.RelayFromAddr) target := protoAddrToNetAddr(m.RelayToAddr) - logMsg := rm.l.WithFields(logrus.Fields{ - "relayFrom": from, - "relayTo": target, - "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnAddrs": h.vpnAddrs}) + logMsg := rm.l.With( + "relayFrom", from, + "relayTo", target, + "initiatorRelayIndex", m.InitiatorRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. if f.myVpnAddrsTable.Contains(from) { - logMsg.WithField("myIP", from).Error("Discarding relay request from myself") + logMsg.Error("Discarding relay request from myself", "myIP", from) return } @@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + logMsg.Error("Existing relay mismatch with CreateRelayRequest", + "existingRemoteIndex", existingRelay.RemoteIndex) return } case Disestablished: if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. // This should never happen. 
The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + logMsg.Error("Existing relay mismatch with CreateRelayRequest", + "existingRemoteIndex", existingRelay.RemoteIndex) return } // Mark the relay as 'Established' because it's safe to use again h.relayState.UpdateRelayForByIpState(from, Established) case PeerRequested: // I should never be in this state, because I am terminal, not forwarding. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex, - "state": existingRelay.State}).Error("Unexpected Relay State found") + logMsg.Error("Unexpected Relay State found", + "existingRemoteIndex", existingRelay.RemoteIndex, + "state", existingRelay.State) } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) if err != nil { - logMsg.WithError(err).Error("Failed to add relay") + logMsg.Error("Failed to add relay", "error", err) return } } relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { - logMsg.WithField("from", from).Error("Relay State not found") + logMsg.Error("Relay State not found", "from", from) return } @@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f msg, err := resp.Marshal() if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": from, - "relayTo": target, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnAddrs": h.vpnAddrs}). 
- Info("send CreateRelayResponse") + rm.l.Info("send CreateRelayResponse", + "relayFrom", from, + "relayTo", target, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) } return } else { @@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if v == cert.Version1 { if !h.vpnAddrs[0].Is4() { - rm.l.WithField("relayFrom", h.vpnAddrs[0]). - WithField("relayTo", target). - WithField("initiatorRelayIndex", req.InitiatorRelayIndex). - WithField("responderRelayIndex", req.ResponderRelayIndex). - WithField("vpnAddr", target). - Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") + rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address", + "relayFrom", h.vpnAddrs[0], + "relayTo", target, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddr", target, + ) return } @@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f msg, err := req.Marshal() if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to marshal Control message to create relay") + logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": h.vpnAddrs[0], - "relayTo": target, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnAddr": target}). 
- Info("send CreateRelayRequest") + rm.l.Info("send CreateRelayRequest", + "relayFrom", h.vpnAddrs[0], + "relayTo", target, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddr", target, + ) } // Also track the half-created Relay state just received @@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if !ok { _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to allocate a local index for relay") + logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err) return } } diff --git a/remote_list.go b/remote_list.go index 8338d517..7b95de87 100644 --- a/remote_list.go +++ b/remote_list.go @@ -2,6 +2,7 @@ package nebula import ( "context" + "log/slog" "net" "net/netip" "slices" @@ -10,8 +11,6 @@ import ( "sync" "sync/atomic" "time" - - "github.com/sirupsen/logrus" ) // forEachFunc is used to benefit folks that want to do work inside the lock @@ -66,11 +65,11 @@ type hostnamesResults struct { network string lookupTimeout time.Duration cancelFn func() - l *logrus.Logger + l *slog.Logger ips atomic.Pointer[map[netip.AddrPort]struct{}] } -func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { +func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { r := &hostnamesResults{ hostnames: make([]hostnamePort, len(hostPorts)), network: network, @@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) timeoutCancel() if err != nil { - 
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") + l.Error("DNS resolution failed for static_map host", + "hostname", hostPort.name, + "network", r.network, + "error", err, + ) continue } for _, a := range addrs { @@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, } } if different { - l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") + l.Info("DNS results changed for host list", + "origSet", origSet, + "newSet", netipAddrs, + ) r.ips.Store(&netipAddrs) onUpdate() } diff --git a/service/service_test.go b/service/service_test.go index c6b87423..4bcc8437 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,11 +10,11 @@ import ( "time" "dario.cat/mergo" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/overlay" "go.yaml.in/yaml/v3" "golang.org/x/sync/errgroup" @@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n panic(err) } - logger := logrus.New() - logger.Out = os.Stdout + logger := logging.NewLogger(os.Stdout) control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { diff --git a/ssh.go b/ssh.go index b2912d55..3863b5ec 100644 --- a/ssh.go +++ b/ssh.go @@ -6,21 +6,21 @@ import ( "errors" "flag" "fmt" + "log/slog" "maps" "net" "net/netip" "os" "path/filepath" - "reflect" "runtime" "runtime/pprof" "sort" "strconv" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/sshd" ) @@ -57,12 +57,12 @@ type sshDeviceInfoFlags struct { Pretty bool } -func 
wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { +func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) { c.RegisterReloadCallback(func(c *config.C) { if c.GetBool("sshd.enabled", false) { sshRun, err := configSSH(l, ssh, c) if err != nil { - l.WithError(err).Error("Failed to reconfigure the sshd") + l.Error("Failed to reconfigure the sshd", "error", err) ssh.Stop() } if sshRun != nil { @@ -78,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { // updates the passed-in SSHServer. On success, it returns a function // that callers may invoke to run the configured ssh server. On // failure, it returns nil, error. -func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { +func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { listen := c.GetString("sshd.listen", "") if listen == "" { return nil, fmt.Errorf("sshd.listen must be provided") @@ -120,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, caAuthorizedKey := range rawCAs { err := ssh.AddTrustedCA(caAuthorizedKey) if err != nil { - l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring") + l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey) continue } } @@ -131,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, rk := range keys { kDef, ok := rk.(map[string]any) if !ok { - l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") + l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk) continue } user, ok := kDef["user"].(string) if !ok { - l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field") + l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk) continue } @@ -146,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), 
erro case string: err := ssh.AddAuthorizedKey(user, v) if err != nil { - l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key") + l.Warn("Failed to authorize key", + "error", err, + "sshKeyConfig", rk, + "sshKey", v, + ) continue } @@ -154,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, subK := range v { sk, ok := subK.(string) if !ok { - l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key") + l.Warn("Did not understand ssh key", + "sshKeyConfig", rk, + "sshKey", subK, + ) continue } err := ssh.AddAuthorizedKey(user, sk) if err != nil { - l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key") + l.Warn("Failed to authorize key", + "error", err, + "sshKeyConfig", sk, + ) continue } } default: - l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood") + l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk) } } } else { @@ -178,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro ssh.Stop() runner = func() { if err := ssh.Run(listen); err != nil { - l.WithField("err", err).Warn("Failed to run the SSH server") + l.Warn("Failed to run the SSH server", "error", err) } } } else { @@ -188,7 +198,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return runner, nil } -func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { +func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { // sandboxDir defaults to a dir in temp. The intention is that end user will // create this dir as needed. Overriding this config value to "" allows // writing to anywhere in the system. 
@@ -789,36 +799,45 @@ func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) } -func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error { + ctrl, ok := l.Handler().(interface { + GetLevel() slog.Level + SetLevel(slog.Level) + }) + if !ok { + return w.WriteLine("Log level is not reconfigurable on this logger") + } + if len(a) == 0 { - return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel()))) } - level, err := logrus.ParseLevel(a[0]) + level, err := logging.ParseLevel(strings.ToLower(a[0])) if err != nil { - return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels)) + return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: trace, debug, info, warn, error", a)) } - l.SetLevel(level) - return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + ctrl.SetLevel(level) + return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel()))) } -func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error { + ctrl, ok := l.Handler().(interface { + GetFormat() string + SetFormat(string) error + }) + if !ok { + return w.WriteLine("Log format is not reconfigurable on this logger") + } + if len(a) == 0 { - return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat())) } - logFormat := strings.ToLower(a[0]) - switch logFormat { - case "text": - l.Formatter = &logrus.TextFormatter{} - case "json": - l.Formatter = &logrus.JSONFormatter{} - default: - return fmt.Errorf("unknown log format `%s`. 
possible formats: %s", logFormat, []string{"text", "json"}) + if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil { + return err } - - return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat())) } func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { diff --git a/sshd/server.go b/sshd/server.go index 4b5cc3e0..38886e53 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -5,16 +5,16 @@ import ( "context" "errors" "fmt" + "log/slog" "net" "github.com/armon/go-radix" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) type SSHServer struct { config *ssh.ServerConfig - l *logrus.Entry + l *slog.Logger certChecker *ssh.CertChecker @@ -33,7 +33,7 @@ type SSHServer struct { } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen -func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { +func NewSSHServer(l *slog.Logger) (*SSHServer, error) { ctx, cancel := context.WithCancel(context.Background()) s := &SSHServer{ @@ -121,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error { } s.trustedCAs = append(s.trustedCAs, pk) - s.l.WithField("sshKey", pubKey).Info("Trusted CA key") + s.l.Info("Trusted CA key", "sshKey", pubKey) return nil } @@ -139,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error { } tk[string(pk.Marshal())] = true - s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key") + s.l.Info("Authorized ssh key", + "sshKey", pubKey, + "sshUser", user, + ) return nil } @@ -156,7 +159,7 @@ func (s *SSHServer) Run(addr string) error { return err } - s.l.WithField("sshListener", addr).Info("SSH server is listening") + s.l.Info("SSH server is listening", "sshListener", addr) // Run loops until there is an error s.run() @@ -172,7 +175,7 @@ func (s *SSHServer) run() { c, err := s.listener.Accept() if err != nil { if !errors.Is(err, 
net.ErrClosed) { - s.l.WithError(err).Warn("Error in listener, shutting down") + s.l.Warn("Error in listener, shutting down", "error", err) } return } @@ -193,23 +196,29 @@ func (s *SSHServer) run() { } if err != nil { - l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + l := s.l.With( + "error", err, + "remoteAddress", c.RemoteAddr(), + ) if conn != nil { - l = l.WithField("sshUser", conn.User()) + l = l.With("sshUser", conn.User()) conn.Close() } if fp != "" { - l = l.WithField("sshFingerprint", fp) + l = l.With("sshFingerprint", fp) } l.Warn("failed to handshake") sessionCancel() return } - l := s.l.WithField("sshUser", conn.User()) - l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") + l := s.l.With("sshUser", conn.User()) + l.Info("ssh user logged in", + "remoteAddress", c.RemoteAddr(), + "sshFingerprint", fp, + ) - NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session")) + NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session")) go ssh.DiscardRequests(reqs) @@ -221,7 +230,7 @@ func (s *SSHServer) Stop() { // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { - s.l.WithError(err).Warn("Failed to close the sshd listener") + s.l.Warn("Failed to close the sshd listener", "error", err) } } } diff --git a/sshd/session.go b/sshd/session.go index 39c81bd0..1c8e1a9b 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -2,25 +2,25 @@ package sshd import ( "fmt" + "log/slog" "sort" "strings" "github.com/anmitsu/go-shlex" "github.com/armon/go-radix" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "golang.org/x/term" ) type session struct { - l *logrus.Entry + l *slog.Logger c *ssh.ServerConn term *term.Terminal commands *radix.Tree cancel func() } -func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan 
ssh.NewChannel, cancel func(), l *logrus.Entry) *session { +func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *slog.Logger) *session { s := &session{ commands: radix.NewFromMap(commands.ToMap()), l: l, @@ -45,14 +45,14 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) { defer s.Close() for newChannel := range chans { if newChannel.ChannelType() != "session" { - s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") + s.l.Error("unknown channel type", "sshChannelType", newChannel.ChannelType()) newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } channel, requests, err := newChannel.Accept() if err != nil { - s.l.WithError(err).Warn("could not accept channel") + s.l.Warn("could not accept channel", "error", err) continue } @@ -95,12 +95,12 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { return default: - s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request") + s.l.Debug("Rejected unknown request", "sshRequest", req.Type) err = req.Reply(false, nil) } if err != nil { - s.l.WithError(err).Info("Error handling ssh session requests") + s.l.Info("Error handling ssh session requests", "error", err) return } } diff --git a/stats.go b/stats.go index c88c45cc..c7bf3a06 100644 --- a/stats.go +++ b/stats.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log" + "log/slog" "net" "net/http" "runtime" @@ -15,14 +16,13 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) // startStats initializes stats from config. On success, if any further work // is needed to serve stats, it returns a func to handle that work. If no // work is needed, it'll return nil. On failure, it returns nil, error. 
-func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { +func startStats(l *slog.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { mType := c.GetString("stats.type", "") if mType == "" || mType == "none" { return nil, nil @@ -59,7 +59,7 @@ func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest b return startFn, nil } -func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error { +func startGraphiteStats(l *slog.Logger, i time.Duration, c *config.C, configTest bool) error { proto := c.GetString("stats.protocol", "tcp") host := c.GetString("stats.host", "") if host == "" { @@ -73,13 +73,17 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTe } if !configTest { - l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) + l.Info("Starting graphite", + "interval", i, + "prefix", prefix, + "addr", addr.String(), + ) go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) } return nil } -func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { +func startPrometheusStats(l *slog.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { namespace := c.GetString("stats.namespace", "") subsystem := c.GetString("stats.subsystem", "") @@ -116,9 +120,15 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildV var startFn func() if !configTest { + // promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger, + // so bridge our slog.Logger back to a *log.Logger that emits at Error. 
+ errLog := slog.NewLogLogger(l.Handler(), slog.LevelError) startFn = func() { - l.Infof("Prometheus stats listening on %s at %s", listen, path) - http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) + l.Info("Prometheus stats listening", + "listen", listen, + "path", path, + ) + http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog})) log.Fatal(http.ListenAndServe(listen, nil)) } } diff --git a/test/logger.go b/test/logger.go index b5a717d8..faab0b69 100644 --- a/test/logger.go +++ b/test/logger.go @@ -1,29 +1,73 @@ package test import ( + "context" "io" + "log/slog" "os" + "time" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/logging" ) -func NewLogger() *logrus.Logger { - l := logrus.New() - +// NewLogger returns a *slog.Logger suitable for use in tests. Output goes to +// io.Discard by default; set TEST_LOGS=1 (info), 2 (debug), or 3 (trace) to +// stream output to stderr for local debugging. +func NewLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(io.Discard) - return l + return slog.New(slog.DiscardHandler) } + level := slog.LevelInfo switch v { case "2": - l.SetLevel(logrus.DebugLevel) + level = slog.LevelDebug case "3": - l.SetLevel(logrus.TraceLevel) - default: - l.SetLevel(logrus.InfoLevel) + level = logging.LevelTrace } - - return l + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} + +// NewLoggerWithOutput returns a *slog.Logger whose text output is captured by +// w. Timestamps are suppressed so tests can assert on exact output without +// baking the current time into expected strings. +func NewLoggerWithOutput(w io.Writer) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, nil)}) +} + +// NewLoggerWithOutputAndLevel is NewLoggerWithOutput with an explicit level +// so tests can exercise Enabled-gated paths. 
+func NewLoggerWithOutputAndLevel(w io.Writer, level slog.Level) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, &slog.HandlerOptions{Level: level})}) +} + +// NewJSONLoggerWithOutput returns a *slog.Logger emitting JSON to w with +// timestamps suppressed, for tests that pin the JSON shape. +func NewJSONLoggerWithOutput(w io.Writer, level slog.Level) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewJSONHandler(w, &slog.HandlerOptions{Level: level})}) +} + +// stripTimeHandler zeros each record's time before delegating so slog's +// built-in handlers skip emitting the time attribute. Used to avoid +// timestamp-dependent assertions in tests without resorting to ReplaceAttr. +type stripTimeHandler struct { + inner slog.Handler +} + +func (h *stripTimeHandler) Enabled(ctx context.Context, l slog.Level) bool { + return h.inner.Enabled(ctx, l) +} + +func (h *stripTimeHandler) Handle(ctx context.Context, r slog.Record) error { + r.Time = time.Time{} + return h.inner.Handle(ctx, r) +} + +func (h *stripTimeHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &stripTimeHandler{inner: h.inner.WithAttrs(attrs)} +} + +func (h *stripTimeHandler) WithGroup(name string) slog.Handler { + return &stripTimeHandler{inner: h.inner.WithGroup(name)} } diff --git a/udp/udp_android.go b/udp/udp_android.go index bb191954..3fc68003 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -9,11 +9,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index 65ef31a5..c42a3c18 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -12,11 +12,12 @@ import ( "net/netip" "syscall" - 
"github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 863c98f3..8a4f5b18 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -8,12 +8,12 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net" "net/netip" "syscall" "unsafe" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) @@ -22,12 +22,12 @@ type StdConn struct { *net.UDPConn isV4 bool sysFd uintptr - l *logrus.Logger + l *slog.Logger } var _ Conn = &StdConn{} -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -176,7 +176,7 @@ func (u *StdConn) ListenOut(r EncReader) error { return err } - u.l.WithError(err).Error("unexpected udp socket receive error") + u.l.Error("unexpected udp socket receive error", "error", err) } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) @@ -196,7 +196,7 @@ func (u *StdConn) Rebind() error { } if err != nil { - u.l.WithError(err).Error("Failed to rebind udp socket") + u.l.Error("Failed to rebind udp socket", "error", err) } return nil diff --git a/udp/udp_generic.go b/udp/udp_generic.go index ad26f794..131eb73b 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -12,22 +12,22 @@ import ( "context" "errors" "fmt" + "log/slog" "net" "net/netip" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type GenericConn struct { *net.UDPConn - l *logrus.Logger + l *slog.Logger } 
var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -88,7 +88,7 @@ func (u *GenericConn) ListenOut(r EncReader) error { // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() - u.l.WithError(err).Warn("unexpected udp socket receive error") + u.l.Warn("unexpected udp socket receive error", "error", err) } continue } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 21a34147..3e2d726a 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -7,13 +7,13 @@ import ( "context" "encoding/binary" "fmt" + "log/slog" "net" "net/netip" "syscall" "unsafe" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) @@ -22,7 +22,7 @@ type StdConn struct { udpConn *net.UDPConn rawConn syscall.RawConn isV4 bool - l *logrus.Logger + l *slog.Logger batch int } @@ -38,7 +38,7 @@ func setReusePort(network, address string, c syscall.RawConn) error { return opErr } -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { listen := netip.AddrPortFrom(ip, uint16(port)) lc := net.ListenConfig{} if multi { @@ -242,12 +242,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetRecvBuffer() if err == nil { - u.l.WithField("size", s).Info("listen.read_buffer was set") + u.l.Info("listen.read_buffer was set", "size", s) } else { - u.l.WithError(err).Warn("Failed to get listen.read_buffer") + u.l.Warn("Failed to get listen.read_buffer", 
"error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.read_buffer") + u.l.Error("Failed to set listen.read_buffer", "error", err) } } @@ -257,12 +257,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetSendBuffer() if err == nil { - u.l.WithField("size", s).Info("listen.write_buffer was set") + u.l.Info("listen.write_buffer was set", "size", s) } else { - u.l.WithError(err).Warn("Failed to get listen.write_buffer") + u.l.Warn("Failed to get listen.write_buffer", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.write_buffer") + u.l.Error("Failed to set listen.write_buffer", "error", err) } } @@ -273,12 +273,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetSoMark() if err == nil { - u.l.WithField("mark", s).Info("listen.so_mark was set") + u.l.Info("listen.so_mark was set", "mark", s) } else { - u.l.WithError(err).Warn("Failed to get listen.so_mark") + u.l.Warn("Failed to get listen.so_mark", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.so_mark") + u.l.Error("Failed to set listen.so_mark", "error", err) } } } diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go index 3b69159a..4b2de75a 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -11,11 +11,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 607b978e..d110af19 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/netip" "sync" @@ -17,7 +18,6 @@ import ( "time" "unsafe" - "github.com/sirupsen/logrus" 
"github.com/slackhq/nebula/config" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" @@ -53,14 +53,14 @@ type ringBuffer struct { type RIOConn struct { isOpen atomic.Bool - l *logrus.Logger + l *slog.Logger sock windows.Handle rx, tx ringBuffer rq winrio.Rq results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { +func NewRIOListener(l *slog.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } @@ -83,7 +83,7 @@ func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, erro return u, nil } -func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { +func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error { var err error u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) if err != nil { @@ -103,7 +103,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. - l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl") + l.Debug("failed to set UDP_CONNRESET ioctl", "error", err) } ret = 0 @@ -114,7 +114,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. 
- l.WithError(err).Debug("failed to set UDP_NETRESET ioctl") + l.Debug("failed to set UDP_NETRESET ioctl", "error", err) } err = u.rx.Open() @@ -156,7 +156,7 @@ func (u *RIOConn) ListenOut(r EncReader) error { // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() - u.l.WithError(err).Warn("unexpected udp socket receive error") + u.l.Warn("unexpected udp socket receive error", "error", err) } continue } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 388b17d0..fcd0967c 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -4,12 +4,13 @@ package udp import ( + "context" "io" + "log/slog" "net/netip" "os" "sync" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -46,10 +47,10 @@ type TesterConn struct { done chan struct{} closeOnce sync.Once - l *logrus.Logger + l *slog.Logger } -func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), @@ -67,11 +68,12 @@ func (u *TesterConn) Send(packet *Packet) { if err := h.Parse(packet.Data); err != nil { panic(err) } - if u.l.Level >= logrus.DebugLevel { - u.l.WithField("header", h). - WithField("udpAddr", packet.From). - WithField("dataLen", len(packet.Data)). 
- Debug("UDP receiving injected packet") + if u.l.Enabled(context.Background(), slog.LevelDebug) { + u.l.Debug("UDP receiving injected packet", + "header", h, + "udpAddr", packet.From, + "dataLen", len(packet.Data), + ) } select { case <-u.done: diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 1b777c37..7969f7e8 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -5,14 +5,13 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - - "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between @@ -25,7 +24,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return rc, nil } - l.WithError(err).Error("Falling back to standard udp sockets") + l.Error("Falling back to standard udp sockets", "error", err) return NewGenericListener(l, ip, port, multi, batch) } diff --git a/util/error.go b/util/error.go index 814c77a1..14371d3f 100644 --- a/util/error.go +++ b/util/error.go @@ -1,10 +1,10 @@ package util import ( + "context" "errors" "fmt" - - "github.com/sirupsen/logrus" + "log/slog" ) type ContextualError struct { @@ -28,12 +28,12 @@ func ContextualizeIfNeeded(msg string, err error) error { } // LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError -func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { +func LogWithContextIfNeeded(msg string, err error, l *slog.Logger) { switch v := err.(type) { case *ContextualError: v.Log(l) default: - l.WithError(err).Error(msg) + l.Error(msg, "error", err) } } @@ -51,10 +51,19 @@ func (ce *ContextualError) Unwrap() error { return ce.RealError } -func (ce 
*ContextualError) Log(lr *logrus.Logger) { - if ce.RealError != nil { - lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) - } else { - lr.WithFields(ce.Fields).Error(ce.Context) +// Log emits ce as a single error-level log line with Fields and RealError +// promoted to top-level attributes, producing a flat shape callers can grep +// or parse without walking into a nested object. +func (ce *ContextualError) Log(l *slog.Logger) { + attrs := make([]slog.Attr, 0, len(ce.Fields)+1) + for k, v := range ce.Fields { + attrs = append(attrs, slog.Any(k, v)) } + if ce.RealError != nil { + attrs = append(attrs, slog.Any("error", ce.RealError)) + } + // LogAttrs is intentional: attrs is built from a map[string]any so it has + // no pair-form equivalent. + //nolint:sloglint + l.LogAttrs(context.Background(), slog.LevelError, ce.Context, attrs...) } diff --git a/util/error_test.go b/util/error_test.go index 692c1840..30e39e33 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -1,95 +1,67 @@ package util import ( + "bytes" "errors" "fmt" "testing" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) type m = map[string]any -type TestLogWriter struct { - Logs []string -} - -func NewTestLogWriter() *TestLogWriter { - return &TestLogWriter{Logs: make([]string, 0)} -} - -func (tl *TestLogWriter) Write(p []byte) (n int, err error) { - tl.Logs = append(tl.Logs, string(p)) - return len(p), nil -} - -func (tl *TestLogWriter) Reset() { - tl.Logs = tl.Logs[:0] -} - func TestContextualError_Log(t *testing.T) { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{ - DisableTimestamp: true, - DisableColors: true, - } - - tl := NewTestLogWriter() - l.Out = tl + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutput(buf) // Test a full context line - tl.Reset() + buf.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test 
message\" error=error field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String()) // Test a line with an error and msg but no fields - tl.Reset() + buf.Reset() e = NewContextualError("test message", nil, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" error=error\n", buf.String()) // Test just a context and fields - tl.Reset() + buf.Reset() e = NewContextualError("test message", m{"field": "1"}, nil) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1\n", buf.String()) // Test just a context - tl.Reset() + buf.Reset() e = NewContextualError("test message", nil, nil) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\"\n", buf.String()) // Test just an error - tl.Reset() + buf.Reset() e = NewContextualError("", nil, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"\" error=error\n", buf.String()) } func TestLogWithContextIfNeeded(t *testing.T) { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{ - DisableTimestamp: true, - DisableColors: true, - } - - tl := NewTestLogWriter() - l.Out = tl + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutput(buf) // Test ignoring fallback context - tl.Reset() + buf.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) LogWithContextIfNeeded("This should get thrown away", e, l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String()) // Test using fallback context - tl.Reset() + buf.Reset() err := fmt.Errorf("this is a normal error") 
LogWithContextIfNeeded("Fallback context woo", err, l) - assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"Fallback context woo\" error=\"this is a normal error\"\n", buf.String()) } func TestContextualizeIfNeeded(t *testing.T) { From 1ab1f71dba7b5b2f543f444f678db5bd46406bb1 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 27 Apr 2026 12:25:24 -0500 Subject: [PATCH 38/44] Make stats a server we can reconfigure and start/stop (#1670) --- examples/config.yml | 5 + main.go | 4 +- stats.go | 424 ++++++++++++++++++++++++++++++++++---------- stats_test.go | 410 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 744 insertions(+), 99 deletions(-) create mode 100644 stats_test.go diff --git a/examples/config.yml b/examples/config.yml index b02b3d58..f5752ae4 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -304,6 +304,9 @@ logging: #disable_timestamp: true # Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable. +# The stats section is reloadable. A HUP may change the backend, toggle stats +# on or off, switch the listen/host address, or pick up new DNS for the +# configured graphite host. #stats: #type: graphite #prefix: nebula @@ -321,10 +324,12 @@ logging: # enables counter metrics for meta packets # e.g.: `messages.tx.handshake` # NOTE: `message.{tx,rx}.recv_error` is always emitted + # Not reloadable. #message_metrics: false # enables detailed counter metrics for lighthouse packets # e.g.: `lighthouse.rx.HostQuery` + # Not reloadable. 
#lighthouse_metrics: false # Handshake Manager Settings diff --git a/main.go b/main.go index f692f317..eef13c97 100644 --- a/main.go +++ b/main.go @@ -246,7 +246,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev go handshakeManager.Run(ctx) } - statsStart, err := startStats(l, c, buildVersion, configTest) + stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } @@ -266,7 +266,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev ctx: ctx, cancel: cancel, sshStart: sshStart, - statsStart: statsStart, + statsStart: stats.Start, dnsStart: ds.Start, lighthouseStart: lightHouse.StartUpdateWorker, connectionManagerStart: connManager.Start, diff --git a/stats.go b/stats.go index c7bf3a06..97ce7cf5 100644 --- a/stats.go +++ b/stats.go @@ -1,14 +1,16 @@ package nebula import ( + "context" "errors" "fmt" - "log" "log/slog" "net" "net/http" "runtime" "strconv" + "sync" + "sync/atomic" "time" graphite "github.com/cyberdelia/go-metrics-graphite" @@ -19,119 +21,347 @@ import ( "github.com/slackhq/nebula/config" ) -// startStats initializes stats from config. On success, if any further work -// is needed to serve stats, it returns a func to handle that work. If no -// work is needed, it'll return nil. On failure, it returns nil, error. -func startStats(l *slog.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { - mType := c.GetString("stats.type", "") - if mType == "" || mType == "none" { - return nil, nil - } +// statsServer owns nebula's stats subsystem: the periodic metric capture +// goroutine and (for prometheus) an HTTP listener. It mirrors the lifecycle +// shape of dnsServer: constructor wires the reload callback, reload records +// config, Start builds and runs the runtime, Stop tears it down. 
+type statsServer struct { + l *slog.Logger + ctx context.Context + buildVersion string + configTest bool - interval := c.GetDuration("stats.interval", 0) - if interval == 0 { - return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) - } + // enabled mirrors "stats configured to a real backend". Start consults + // it so callers don't need to know the gating rules. + enabled atomic.Bool - var startFn func() - switch mType { - case "graphite": - err := startGraphiteStats(l, interval, c, configTest) - if err != nil { - return nil, err - } - case "prometheus": - var err error - startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest) - if err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("stats.type was not understood: %s", mType) - } - - metrics.RegisterDebugGCStats(metrics.DefaultRegistry) - metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry) - - go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval) - go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval) - - return startFn, nil + runMu sync.Mutex + runCfg *statsConfig + run *statsRuntime // non-nil while a runtime is live } -func startGraphiteStats(l *slog.Logger, i time.Duration, c *config.C, configTest bool) error { - proto := c.GetString("stats.protocol", "tcp") - host := c.GetString("stats.host", "") - if host == "" { - return errors.New("stats.host can not be empty") +// statsRuntime is the live state owned by a single Start invocation. Start +// stashes a pointer under runMu; Stop and Start's own exit path use pointer +// equality to tell "my runtime" apart from one that replaced it after a +// reload. +type statsRuntime struct { + cancel context.CancelFunc + listener *http.Server // nil for graphite +} + +// statsConfig is the snapshot of stats-related config that drives the runtime. +// It is comparable with == so reload can detect "no change" cheaply. 
+type statsConfig struct { + typ string + interval time.Duration + graphite graphiteConfig + prom promConfig +} + +type graphiteConfig struct { + protocol string + host string + // resolvedAddr is the string form of host resolved at config-load time. + // Including it in the struct means a SIGHUP picks up DNS changes even + // when stats.host hasn't been edited. + resolvedAddr string + prefix string +} + +type promConfig struct { + listen string + path string + namespace string + subsystem string +} + +// newStatsServerFromConfig builds a statsServer, applies the initial config, +// and registers a reload callback. The reload callback is registered before +// the initial config is applied so a SIGHUP can later enable, fix, or disable +// stats even if the initial application failed. +// +// Start is safe to call unconditionally: it no-ops when stats are disabled. +// The returned pointer is always non-nil, even on error. +func newStatsServerFromConfig(ctx context.Context, l *slog.Logger, c *config.C, buildVersion string, configTest bool) (*statsServer, error) { + s := &statsServer{ + l: l, + ctx: ctx, + buildVersion: buildVersion, + configTest: configTest, } - prefix := c.GetString("stats.prefix", "nebula") - addr, err := net.ResolveTCPAddr(proto, host) + c.RegisterReloadCallback(func(c *config.C) { + if err := s.reload(c, false); err != nil { + s.l.Error("Failed to reload stats from config", "error", err) + } + }) + + if err := s.reload(c, true); err != nil { + return s, err + } + return s, nil +} + +// reload records the latest config. On the initial call it only records it; +// Control.Start is what launches the first runtime via statsStart. 
On later +// calls it reconciles the running runtime with the new config: +// +// - newly enabled -> spawn Start +// - newly disabled -> Stop the runtime +// - config changed (still enabled) -> Stop the old, Start the new +// - no change -> no-op +func (s *statsServer) reload(c *config.C, initial bool) error { + newCfg, err := loadStatsConfig(c) if err != nil { - return fmt.Errorf("error while setting up graphite sink: %s", err) + return err + } + enabled := newCfg.typ != "" && newCfg.typ != "none" + + s.runMu.Lock() + sameCfg := s.runCfg != nil && *s.runCfg == newCfg + s.runCfg = &newCfg + running := s.run != nil + s.runMu.Unlock() + + s.enabled.Store(enabled) + + if initial || sameCfg { + return nil } - if !configTest { - l.Info("Starting graphite", - "interval", i, - "prefix", prefix, - "addr", addr.String(), - ) - go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) + if running { + s.Stop() + } + if enabled && !s.configTest { + go s.Start() } return nil } -func startPrometheusStats(l *slog.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { - namespace := c.GetString("stats.namespace", "") - subsystem := c.GetString("stats.subsystem", "") - - listen := c.GetString("stats.listen", "") - if listen == "" { - return nil, fmt.Errorf("stats.listen should not be empty") +// Start builds the runtime from the latest config, spawns the capture loop, +// and blocks until Stop is called or ctx fires. For prometheus it also serves +// the HTTP listener. For graphite it blocks on the capture loop's context. +// Safe to call when stats are disabled or already running (both no-op). 
+func (s *statsServer) Start() { + if !s.enabled.Load() || s.configTest { + return } - path := c.GetString("stats.path", "") - if path == "" { - return nil, fmt.Errorf("stats.path should not be empty") + s.runMu.Lock() + if s.ctx.Err() != nil || s.run != nil || s.runCfg == nil { + s.runMu.Unlock() + return + } + cfg := *s.runCfg + captureFns, listener := s.buildRuntime(cfg) + runCtx, cancel := context.WithCancel(s.ctx) + rt := &statsRuntime{cancel: cancel, listener: listener} + s.run = rt + s.runMu.Unlock() + + go captureStatsLoop(runCtx, cfg.interval, captureFns) + + cleanExit := true + if listener == nil { + // Graphite: no HTTP listener to serve; block until teardown. + <-runCtx.Done() + } else { + cleanExit = s.serveListener(listener) } - pr := prometheus.NewRegistry() - pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i) - if !configTest { - go pClient.UpdatePrometheusMetrics() - } - - // Export our version information as labels on a static gauge - g := prometheus.NewGauge(prometheus.GaugeOpts{ - Namespace: namespace, - Subsystem: subsystem, - Name: "info", - Help: "Version information for the Nebula binary", - ConstLabels: prometheus.Labels{ - "version": buildVersion, - "goversion": runtime.Version(), - "boringcrypto": strconv.FormatBool(boringEnabled()), - }, - }) - pr.MustRegister(g) - g.Set(1) - - var startFn func() - if !configTest { - // promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger, - // so bridge our slog.Logger back to a *log.Logger that emits at Error. - errLog := slog.NewLogLogger(l.Handler(), slog.LevelError) - startFn = func() { - l.Info("Prometheus stats listening", - "listen", listen, - "path", path, - ) - http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog})) - log.Fatal(http.ListenAndServe(listen, nil)) + // Clear our runtime only if nothing has replaced it. Stop races through + // here too but leaves s.run == nil, so the pointer check skips. 
+ s.runMu.Lock() + if s.run == rt { + rt.cancel() + s.run = nil + // A listener that exited with an error (e.g., bind conflict) leaves + // runCfg cached as if it were applied. Drop it so a SIGHUP with the + // same config re-triggers Start once the user fixes the underlying + // problem. + if !cleanExit { + s.runCfg = nil } } - - return startFn, nil + s.runMu.Unlock() +} + +// serveListener runs ListenAndServe and ensures ctx cancellation unblocks it. +// Returns true if the listener exited cleanly (Stop, ctx cancellation, or any +// other http.ErrServerClosed path), false on an unexpected error. +func (s *statsServer) serveListener(listener *http.Server) bool { + // Per-invocation watcher: ctx cancellation triggers a listener shutdown + // which in turn unblocks ListenAndServe. Closing `done` on exit keeps + // the watcher from outliving this call. + done := make(chan struct{}) + go func() { + select { + case <-s.ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := listener.Shutdown(shutdownCtx); err != nil { + s.l.Warn("Failed to shut down prometheus stats listener", "error", err) + } + case <-done: + } + }() + defer close(done) + + s.l.Info("Starting prometheus stats listener", "addr", listener.Addr) + err := listener.ListenAndServe() + if err == nil || errors.Is(err, http.ErrServerClosed) { + return true + } + s.l.Error("Prometheus stats listener exited", "error", err) + return false +} + +// Stop tears down the active runtime, if any. Idempotent. 
+func (s *statsServer) Stop() { + s.runMu.Lock() + rt := s.run + s.run = nil + s.runMu.Unlock() + if rt == nil { + return + } + rt.cancel() + if rt.listener != nil { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := rt.listener.Shutdown(shutdownCtx); err != nil { + s.l.Warn("Failed to shut down prometheus stats listener", "error", err) + } + cancel() + } +} + +// buildRuntime produces the capture functions and, for prometheus, an un-served +// http.Server from cfg. cfg has already been validated by loadStatsConfig. +func (s *statsServer) buildRuntime(cfg statsConfig) ([]func(), *http.Server) { + // rcrowley/go-metrics guards these registrations with a private sync.Once, + // so subsequent reloads are no-ops. + metrics.RegisterDebugGCStats(metrics.DefaultRegistry) + metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry) + + captureFns := []func(){ + func() { metrics.CaptureDebugGCStatsOnce(metrics.DefaultRegistry) }, + func() { metrics.CaptureRuntimeMemStatsOnce(metrics.DefaultRegistry) }, + } + + switch cfg.typ { + case "graphite": + // loadStatsConfig already resolved and validated the address; re-parse + // the resolved form (no DNS lookup) to get a *net.TCPAddr. 
+ addr, _ := net.ResolveTCPAddr(cfg.graphite.protocol, cfg.graphite.resolvedAddr) + gcfg := graphite.Config{ + Addr: addr, + Registry: metrics.DefaultRegistry, + FlushInterval: cfg.interval, + DurationUnit: time.Nanosecond, + Prefix: cfg.graphite.prefix, + Percentiles: []float64{0.5, 0.75, 0.95, 0.99, 0.999}, + } + captureFns = append(captureFns, func() { + if err := graphite.Once(gcfg); err != nil { + s.l.Error("Graphite export failed", "error", err) + } + }) + s.l.Info("Starting graphite stats", + "interval", cfg.interval, + "prefix", cfg.graphite.prefix, + "addr", addr, + ) + return captureFns, nil + + case "prometheus": + pr := prometheus.NewRegistry() + pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, cfg.prom.namespace, cfg.prom.subsystem, pr, cfg.interval) + captureFns = append(captureFns, func() { + if err := pClient.UpdatePrometheusMetricsOnce(); err != nil { + s.l.Error("Prometheus metrics update failed", "error", err) + } + }) + + g := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: cfg.prom.namespace, + Subsystem: cfg.prom.subsystem, + Name: "info", + Help: "Version information for the Nebula binary", + ConstLabels: prometheus.Labels{ + "version": s.buildVersion, + "goversion": runtime.Version(), + "boringcrypto": strconv.FormatBool(boringEnabled()), + }, + }) + pr.MustRegister(g) + g.Set(1) + + // promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger, + // so bridge our slog.Logger back to a *log.Logger that emits at Error. + errLog := slog.NewLogLogger(s.l.Handler(), slog.LevelError) + mux := http.NewServeMux() + mux.Handle(cfg.prom.path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog})) + return captureFns, &http.Server{Addr: cfg.prom.listen, Handler: mux} + } + return captureFns, nil +} + +// captureStatsLoop runs each fn on every tick of d until ctx is cancelled. 
+func captureStatsLoop(ctx context.Context, d time.Duration, fns []func()) { + t := time.NewTicker(d) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + for _, fn := range fns { + fn() + } + } + } +} + +func loadStatsConfig(c *config.C) (statsConfig, error) { + cfg := statsConfig{ + typ: c.GetString("stats.type", ""), + } + if cfg.typ == "" || cfg.typ == "none" { + return cfg, nil + } + + cfg.interval = c.GetDuration("stats.interval", 0) + if cfg.interval == 0 { + return cfg, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) + } + + switch cfg.typ { + case "graphite": + cfg.graphite.protocol = c.GetString("stats.protocol", "tcp") + cfg.graphite.host = c.GetString("stats.host", "") + if cfg.graphite.host == "" { + return cfg, errors.New("stats.host can not be empty") + } + addr, err := net.ResolveTCPAddr(cfg.graphite.protocol, cfg.graphite.host) + if err != nil { + return cfg, fmt.Errorf("error while setting up graphite sink: %s", err) + } + cfg.graphite.resolvedAddr = addr.String() + cfg.graphite.prefix = c.GetString("stats.prefix", "nebula") + case "prometheus": + cfg.prom.listen = c.GetString("stats.listen", "") + if cfg.prom.listen == "" { + return cfg, errors.New("stats.listen should not be empty") + } + cfg.prom.path = c.GetString("stats.path", "") + if cfg.prom.path == "" { + return cfg, errors.New("stats.path should not be empty") + } + cfg.prom.namespace = c.GetString("stats.namespace", "") + cfg.prom.subsystem = c.GetString("stats.subsystem", "") + default: + return cfg, fmt.Errorf("stats.type was not understood: %s", cfg.typ) + } + + return cfg, nil } diff --git a/stats_test.go b/stats_test.go new file mode 100644 index 00000000..20b17c0e --- /dev/null +++ b/stats_test.go @@ -0,0 +1,410 @@ +package nebula + +import ( + "context" + "io" + "log/slog" + "net" + "strconv" + "testing" + "time" + + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" +) + +func newTestStatsServer(t *testing.T) (*statsServer, *config.C) { + t.Helper() + l := slog.New(slog.DiscardHandler) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return &statsServer{ + l: l, + ctx: ctx, + }, config.NewC(l) +} + +func setStatsConfig(c *config.C, m map[string]any) { + c.Settings["stats"] = m +} + +func currentRuntime(s *statsServer) *statsRuntime { + s.runMu.Lock() + defer s.runMu.Unlock() + return s.run +} + +func TestStatsServer_reload_initial_disabled(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{"type": "none"}) + + require.NoError(t, s.reload(c, true)) + assert.False(t, s.enabled.Load()) + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_initial_invalidInterval(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "host": "127.0.0.1:0", + "prefix": "test", + }) + + err := s.reload(c, true) + require.Error(t, err) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_initial_unknownType(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "carbon", + "interval": "1s", + }) + + err := s.reload(c, true) + require.Error(t, err) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_unchanged_noOp(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{"type": "none"}) + + require.NoError(t, s.reload(c, true)) + require.NoError(t, s.reload(c, false)) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_initial_graphite(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "interval": "1s", + "protocol": "tcp", + "host": "127.0.0.1:2003", + "prefix": "test", + }) + + require.NoError(t, s.reload(c, true)) + assert.True(t, s.enabled.Load()) + // reload only records config; Start builds the runtime. 
+ assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_initial_prometheus(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + + require.NoError(t, s.reload(c, true)) + assert.True(t, s.enabled.Load()) + // reload only records config; Start builds the runtime. + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_Start_graphite_blocksUntilStop(t *testing.T) { + sink := newGraphiteSink(t) + defer sink.Close() + + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "interval": "1s", + "protocol": "tcp", + "host": sink.Addr(), + "prefix": "test", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + + // Wait for Start to publish runtime state. + waitFor(t, func() bool { return currentRuntime(s) != nil }) + rt := currentRuntime(s) + require.NotNil(t, rt) + assert.Nil(t, rt.listener, "graphite has no listener") + + s.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("graphite Start did not return after Stop") + } + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_StartStop_lifecycle(t *testing.T) { + port := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + + waitForListening(t, "127.0.0.1:"+port) + rt := currentRuntime(s) + require.NotNil(t, rt) + require.NotNil(t, rt.listener) + + s.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after Stop") + } + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_disable_stopsRunningRuntime(t *testing.T) { + 
port := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + waitForListening(t, "127.0.0.1:"+port) + + setStatsConfig(c, map[string]any{"type": "none"}) + require.NoError(t, s.reload(c, false)) + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after reload disabled stats") + } + assert.False(t, s.enabled.Load()) + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_changeListener_restartsListener(t *testing.T) { + port1 := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port1, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + firstDone := make(chan struct{}) + go func() { + s.Start() + close(firstDone) + }() + waitForListening(t, "127.0.0.1:"+port1) + first := currentRuntime(s) + require.NotNil(t, first) + + port2 := freeTCPPort(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port2, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, false)) + + select { + case <-firstDone: + case <-time.After(5 * time.Second): + t.Fatal("old Start did not return after reload") + } + + waitForListening(t, "127.0.0.1:"+port2) + second := currentRuntime(s) + require.NotNil(t, second) + assert.NotSame(t, first, second, "expected a new runtime after listen address change") + + s.Stop() +} + +func TestStatsServer_Stop_beforeStart_doesNotBlock(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + stopped := 
make(chan struct{}) + go func() { + s.Stop() + close(stopped) + }() + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatal("Stop hung with no runtime started") + } +} + +func TestStatsServer_configTest_validatesWithoutSpawning(t *testing.T) { + s, c := newTestStatsServer(t) + s.configTest = true + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + + require.NoError(t, s.reload(c, true)) + s.Start() + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_ctxCancel_unblocksStart(t *testing.T) { + // Ensures ctx cancellation alone (no explicit Stop) tears down both + // graphite and prom Start invocations. + port := freeTCPPort(t) + l := slog.New(slog.DiscardHandler) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := &statsServer{l: l, ctx: ctx} + c := config.NewC(l) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + waitForListening(t, "127.0.0.1:"+port) + + cancel() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after ctx cancel") + } +} + +func TestStatsServer_listenerBindFailure_sameCfgReloadRetries(t *testing.T) { + // Hold the port so ListenAndServe will fail on first Start. 
+ blocker, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := strconv.Itoa(blocker.Addr().(*net.TCPAddr).Port) + + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after bind failure") + } + // Bind failure should have dropped the cached config so a same-cfg + // SIGHUP can retry. + s.runMu.Lock() + cfgAfterFailure := s.runCfg + s.runMu.Unlock() + assert.Nil(t, cfgAfterFailure) + + // Free the port and reload with the same config; Start should fire again. + require.NoError(t, blocker.Close()) + require.NoError(t, s.reload(c, false)) + + waitForListening(t, "127.0.0.1:"+port) + require.NotNil(t, currentRuntime(s)) + + s.Stop() +} + +func waitForListening(t *testing.T, addr string) { + t.Helper() + waitFor(t, func() bool { + conn, err := net.DialTimeout("tcp", addr, 200*time.Millisecond) + if err != nil { + return false + } + _ = conn.Close() + return true + }) +} + +// graphiteSink is a minimal TCP accept-and-discard server so graphite.Once +// calls in tests don't spam error logs or wedge on connection refused. 
+type graphiteSink struct { + ln net.Listener +} + +func newGraphiteSink(t *testing.T) *graphiteSink { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + g := &graphiteSink{ln: ln} + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + _, _ = io.Copy(io.Discard, c) + _ = c.Close() + }(conn) + } + }() + return g +} + +func (g *graphiteSink) Addr() string { return g.ln.Addr().String() } +func (g *graphiteSink) Close() { _ = g.ln.Close() } + +func freeTCPPort(t *testing.T) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := ln.Addr().(*net.TCPAddr).Port + require.NoError(t, ln.Close()) + return strconv.Itoa(port) +} From 9ec8cf10f34d881dc96af9659ed305430afcb24e Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 30 Apr 2026 21:30:27 -0500 Subject: [PATCH 39/44] Handshake state machine (#1656) --- cert_test/cert.go | 52 +++ connection_manager_test.go | 17 +- connection_state.go | 67 +-- connection_state_test.go | 114 ++++++ firewall_test.go | 2 +- handshake/credential.go | 57 +++ handshake/errors.go | 21 + handshake/handshake.proto | 29 ++ handshake/helpers_test.go | 116 ++++++ handshake/machine.go | 444 ++++++++++++++++++++ handshake/machine_test.go | 662 ++++++++++++++++++++++++++++++ handshake/patterns.go | 54 +++ handshake/patterns_test.go | 63 +++ handshake/payload.go | 173 ++++++++ handshake/payload_test.go | 361 ++++++++++++++++ handshake_ix.go | 813 ------------------------------------- handshake_manager.go | 631 ++++++++++++++++++++++++++-- handshake_manager_test.go | 137 ++++++- nebula.pb.go | 677 ++---------------------------- nebula.proto | 18 +- pki.go | 121 ++++-- 21 files changed, 3036 insertions(+), 1593 deletions(-) create mode 100644 connection_state_test.go create mode 100644 handshake/credential.go create mode 100644 handshake/errors.go create mode 100644 handshake/handshake.proto create mode 100644 
handshake/helpers_test.go create mode 100644 handshake/machine.go create mode 100644 handshake/machine_test.go create mode 100644 handshake/patterns.go create mode 100644 handshake/patterns_test.go create mode 100644 handshake/payload.go create mode 100644 handshake/payload_test.go delete mode 100644 handshake_ix.go diff --git a/cert_test/cert.go b/cert_test/cert.go index 75134316..c3759f12 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -163,3 +163,55 @@ func P256Keypair() ([]byte, []byte) { pubkey := privkey.PublicKey() return pubkey.Bytes(), privkey.Bytes() } + +// DummyCert is a minimal cert.Certificate implementation for testing error paths. +type DummyCert struct { + Version_ cert.Version + Curve_ cert.Curve + Groups_ []string + IsCA_ bool + Issuer_ string + Name_ string + Networks_ []netip.Prefix + NotAfter_ time.Time + NotBefore_ time.Time + PublicKey_ []byte + Signature_ []byte + UnsafeNetworks_ []netip.Prefix +} + +func (d *DummyCert) Version() cert.Version { return d.Version_ } +func (d *DummyCert) Curve() cert.Curve { return d.Curve_ } +func (d *DummyCert) Groups() []string { return d.Groups_ } +func (d *DummyCert) IsCA() bool { return d.IsCA_ } +func (d *DummyCert) Issuer() string { return d.Issuer_ } +func (d *DummyCert) Name() string { return d.Name_ } +func (d *DummyCert) Networks() []netip.Prefix { return d.Networks_ } +func (d *DummyCert) NotAfter() time.Time { return d.NotAfter_ } +func (d *DummyCert) NotBefore() time.Time { return d.NotBefore_ } +func (d *DummyCert) PublicKey() []byte { return d.PublicKey_ } +func (d *DummyCert) Signature() []byte { return d.Signature_ } +func (d *DummyCert) UnsafeNetworks() []netip.Prefix { return d.UnsafeNetworks_ } +func (d *DummyCert) Fingerprint() (string, error) { return "", nil } +func (d *DummyCert) CheckSignature(key []byte) bool { return false } +func (d *DummyCert) MarshalForHandshakes() ([]byte, error) { return nil, nil } +func (d *DummyCert) MarshalPEM() ([]byte, error) { return nil, nil } 
+func (d *DummyCert) MarshalJSON() ([]byte, error) { return nil, nil } +func (d *DummyCert) Marshal() ([]byte, error) { return nil, nil } +func (d *DummyCert) String() string { return "dummy" } +func (d *DummyCert) Copy() cert.Certificate { return d } +func (d *DummyCert) VerifyPrivateKey(c cert.Curve, k []byte) error { return nil } +func (d *DummyCert) Expired(time.Time) bool { return false } +func (d *DummyCert) MarshalPublicKeyPEM() []byte { return nil } +func (d *DummyCert) PublicKeyPEM() []byte { return nil } + +// NewTestCAPool creates a CAPool from the given CA certificates, panicking on error. +func NewTestCAPool(cas ...cert.Certificate) *cert.CAPool { + pool := cert.NewCAPool() + for _, ca := range cas { + if err := pool.AddCA(ca); err != nil { + panic(err) + } + } + return pool +} diff --git a/connection_manager_test.go b/connection_manager_test.go index a015fba9..7dc08a45 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay/overlaytest" @@ -47,7 +46,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() @@ -80,7 +79,6 @@ func Test_NewConnectionManagerTest(t *testing.T) { } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -130,7 +128,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() @@ -163,7 +161,6 @@ func Test_NewConnectionManagerTest2(t *testing.T) { } hostinfo.ConnectionState = 
&ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -215,7 +212,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() @@ -249,7 +246,6 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -340,9 +336,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ - privateKey: []byte{}, - v1Cert: &dummyCert{}, - v1HandshakeBytes: []byte{}, + privateKey: []byte{}, + v1Cert: &dummyCert{}, + v1Credential: nil, } lh := newTestLighthouse() @@ -372,7 +368,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ConnectionState: &ConnectionState{ myCert: &dummyCert{}, peerCert: cachedPeerCert, - H: &noise.HandshakeState{}, }, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) diff --git a/connection_state.go b/connection_state.go index b85aebd4..47e23b5a 100644 --- a/connection_state.go +++ b/connection_state.go @@ -1,15 +1,12 @@ package nebula import ( - "crypto/rand" "encoding/json" - "fmt" "sync" "sync/atomic" - "github.com/flynn/noise" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/handshake" ) const ReplayWindow = 1024 @@ -17,7 +14,6 @@ const ReplayWindow = 1024 type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState - H *noise.HandshakeState myCert cert.Certificate peerCert *cert.CachedCertificate initiator bool @@ -26,55 +22,24 @@ type ConnectionState struct { writeLock sync.Mutex } -func 
NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { - var dhFunc noise.DHFunc - switch crt.Curve() { - case cert.Curve_CURVE25519: - dhFunc = noise.DH25519 - case cert.Curve_P256: - if cs.pkcs11Backed { - dhFunc = noiseutil.DHP256PKCS11 - } else { - dhFunc = noiseutil.DHP256 - } - default: - return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) - } - - var ncs noise.CipherSuite - if cs.cipher == "chachapoly" { - ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) - } else { - ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) - } - - static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} - hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: ncs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - //NOTE: These should come from CertState (pki.go) when we finally implement it - PresharedKey: []byte{}, - PresharedKeyPlacement: 0, - }) - if err != nil { - return nil, fmt.Errorf("NewConnectionState: %s", err) - } - - // The queue and ready params prevent a counter race that would happen when - // sending stored packets and simultaneously accepting new traffic. +// newConnectionStateFromResult builds a fully-populated ConnectionState from a +// completed handshake.Result. It seeds messageCounter and the replay window so +// that the post-handshake message indices already used on the wire don't count +// as missed traffic in the data plane. +func newConnectionStateFromResult(r *handshake.Result) *ConnectionState { ci := &ConnectionState{ - H: hs, - initiator: initiator, + myCert: r.MyCert, + initiator: r.Initiator, + peerCert: r.RemoteCert, + eKey: NewNebulaCipherState(r.EKey), + dKey: NewNebulaCipherState(r.DKey), window: NewBits(ReplayWindow), - myCert: crt, } - // always start the counter from 2, as packet 1 and packet 2 are handshake packets. 
- ci.messageCounter.Add(2) - - return ci, nil + ci.messageCounter.Add(r.MessageIndex) + for i := uint64(1); i <= r.MessageIndex; i++ { + ci.window.Update(nil, i) + } + return ci } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { diff --git a/connection_state_test.go b/connection_state_test.go new file mode 100644 index 00000000..dea60d39 --- /dev/null +++ b/connection_state_test.go @@ -0,0 +1,114 @@ +package nebula + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// runTestHandshake runs a complete IX handshake between two freshly-built +// peers and returns the initiator and responder Results. Used to produce +// real cipher states for tests that need to exercise post-handshake glue. +func runTestHandshake(t *testing.T) (initR, respR *handshake.Result) { + t.Helper() + + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + makeCreds := func(name string, networks []netip.Prefix) handshake.GetCredentialFunc { + c, _, rawKey, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil, + ) + priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawKey) + require.NoError(t, err) + hsBytes, err := c.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + cred := handshake.NewCredential(c, hsBytes, priv, ncs) + return func(v cert.Version) *handshake.Credential { + if v == cert.Version2 { + return cred + } + return nil + } + } + + verifier := func(c cert.Certificate) (*cert.CachedCertificate, error) { + return caPool.VerifyCertificate(time.Now(), c) 
+ } + + initCreds := makeCreds("initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCreds := makeCreds("responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM, err := handshake.NewMachine( + cert.Version2, initCreds, verifier, + func() (uint32, error) { return 1000, nil }, + true, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + + respM, err := handshake.NewMachine( + cert.Version2, respCreds, verifier, + func() (uint32, error) { return 2000, nil }, + false, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp, respR, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, respR) + + _, initR, err = initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initR) + + return initR, respR +} + +func TestNewConnectionStateFromResult(t *testing.T) { + initR, respR := runTestHandshake(t) + + t.Run("initiator", func(t *testing.T) { + ci := newConnectionStateFromResult(initR) + assert.True(t, ci.initiator) + assert.Equal(t, initR.MyCert, ci.myCert) + assert.Equal(t, initR.RemoteCert, ci.peerCert) + assert.NotNil(t, ci.eKey) + assert.NotNil(t, ci.dKey) + + // IX has 2 handshake messages; the next data-plane send is counter=3. + assert.Equal(t, uint64(2), ci.messageCounter.Load(), + "messageCounter must equal Result.MessageIndex so the next send is N+1") + + // Both handshake counters must be marked seen so they don't appear lost. + // Check returns false if an index has already been recorded. + assert.False(t, ci.window.Check(nil, 1), "counter 1 must already be seen") + assert.False(t, ci.window.Check(nil, 2), "counter 2 must already be seen") + // Counter 3 is the next data-plane message and must NOT be pre-marked. 
+ assert.True(t, ci.window.Check(nil, 3), "counter 3 must not be pre-seeded") + }) + + t.Run("responder", func(t *testing.T) { + ci := newConnectionStateFromResult(respR) + assert.False(t, ci.initiator) + assert.Equal(t, respR.MyCert, ci.myCert) + assert.Equal(t, respR.RemoteCert, ci.peerCert) + assert.NotNil(t, ci.eKey) + assert.NotNil(t, ci.dKey) + assert.Equal(t, uint64(2), ci.messageCounter.Load()) + }) +} diff --git a/firewall_test.go b/firewall_test.go index cbf090fd..40b57477 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -1033,7 +1033,7 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} - cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil, "aes") require.NoError(t, err) conf := config.NewC(test.NewLogger()) diff --git a/handshake/credential.go b/handshake/credential.go new file mode 100644 index 00000000..f6cd5f41 --- /dev/null +++ b/handshake/credential.go @@ -0,0 +1,57 @@ +package handshake + +import ( + "crypto/rand" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" +) + +// Credential holds everything needed to participate in a handshake +// at a given cert version. Version and Curve are read from Cert; the public +// half of the static keypair likewise comes from Cert.PublicKey(). +type Credential struct { + Cert cert.Certificate // the certificate + Bytes []byte // pre-marshaled certificate bytes + privateKey []byte // static private key (public half lives in Cert) + cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash) +} + +// NewCredential creates a Credential with all material needed for handshake +// participation. The cipherSuite should be pre-built by the caller with the +// appropriate DH function, cipher, and hash. 
+func NewCredential( + c cert.Certificate, + hsBytes []byte, + privateKey []byte, + cipherSuite noise.CipherSuite, +) *Credential { + return &Credential{ + Cert: c, + Bytes: hsBytes, + privateKey: privateKey, + cipherSuite: cipherSuite, + } +} + +// buildHandshakeState creates a noise.HandshakeState from this credential. +func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) { + return noise.NewHandshakeState(noise.Config{ + CipherSuite: hc.cipherSuite, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()}, + PresharedKey: []byte{}, + PresharedKeyPlacement: 0, + }) +} + +// GetCredentialFunc returns the handshake credential for the given version, +// or nil if that version is not available. +// +// Implementations must return credentials drawn from a snapshot stable for +// the lifetime of any single Machine. The Machine may call this multiple +// times during a handshake (e.g. when negotiating to the peer's version) +// and assumes the underlying static keypair is consistent across calls. 
+type GetCredentialFunc func(v cert.Version) *Credential diff --git a/handshake/errors.go b/handshake/errors.go new file mode 100644 index 00000000..bb8a5893 --- /dev/null +++ b/handshake/errors.go @@ -0,0 +1,21 @@ +package handshake + +import "errors" + +var ( + ErrInitiateOnResponder = errors.New("initiate called on responder") + ErrInitiateAlreadyCalled = errors.New("initiate already called") + ErrInitiateNotCalled = errors.New("initiate must be called before ProcessPacket for initiators") + ErrPacketTooShort = errors.New("packet too short") + ErrPublicKeyMismatch = errors.New("public key mismatch between certificate and handshake") + ErrIncompleteHandshake = errors.New("handshake completed without receiving required content") + ErrMachineFailed = errors.New("handshake machine has failed") + ErrUnknownSubtype = errors.New("unknown handshake subtype") + ErrMissingContent = errors.New("expected handshake content but message was empty") + ErrUnexpectedContent = errors.New("received unexpected handshake content") + ErrIndexAllocation = errors.New("failed to allocate local index") + ErrNoCredential = errors.New("no handshake credential available for cert version") + ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key") + ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager") + ErrSubtypeMismatch = errors.New("packet subtype does not match handshake machine subtype") +) diff --git a/handshake/handshake.proto b/handshake/handshake.proto new file mode 100644 index 00000000..8eb32aa6 --- /dev/null +++ b/handshake/handshake.proto @@ -0,0 +1,29 @@ +// This file documents the wire format the nebula handshake speaks. It is +// not run through protoc; the encoder/decoder in payload.go is hand-written +// against this shape directly to keep the parser narrow and panic-free. 
+// +// Any change to the wire format must be reflected here, and adding a new +// field requires updating MarshalPayload / unmarshalPayloadDetails together +// with the field-uniqueness and wire-type checks in those functions. + +syntax = "proto3"; +package nebula.handshake; + +message NebulaHandshake { + NebulaHandshakeDetails Details = 1; + bytes Hmac = 2; +} + +message NebulaHandshakeDetails { + bytes Cert = 1; + uint32 InitiatorIndex = 2; + uint32 ResponderIndex = 3; + // Cookie was reserved for an anti-DoS mechanism that was never + // implemented. No released version of nebula has ever populated it; the + // hand-written parser silently skips it on read. + uint64 Cookie = 4 [deprecated = true]; + uint64 Time = 5; + uint32 CertVersion = 8; + // reserved for WIP multiport + reserved 6, 7; +} diff --git a/handshake/helpers_test.go b/handshake/helpers_test.go new file mode 100644 index 00000000..c72346cb --- /dev/null +++ b/handshake/helpers_test.go @@ -0,0 +1,116 @@ +package handshake + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/require" +) + +// testCertState holds cert material for a test peer. 
+type testCertState struct { + version cert.Version + creds map[cert.Version]*Credential +} + +func (s *testCertState) getCredential(v cert.Version) *Credential { + return s.creds[v] +} + +func newTestCertState( + t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix, +) *testCertState { + return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly) +} + +func newTestCertStateWithCipher( + t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix, + cipher noise.CipherFunc, +) *testCertState { + t.Helper() + c, _, rawPrivKey, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil, + ) + + priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawPrivKey) + require.NoError(t, err) + + hsBytes, err := c.MarshalForHandshakes() + require.NoError(t, err) + + ncs := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256) + return &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(c, hsBytes, priv, ncs), + }, + } +} + +func testVerifier(pool *cert.CAPool) CertVerifier { + return func(c cert.Certificate) (*cert.CachedCertificate, error) { + return pool.VerifyCertificate(time.Now(), c) + } +} + +func newTestMachine( + t *testing.T, + cs *testCertState, + verifier CertVerifier, + initiator bool, + localIndex uint32, +) *Machine { + t.Helper() + m, err := NewMachine( + cs.version, cs.getCredential, + verifier, func() (uint32, error) { return localIndex, nil }, + initiator, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + return m +} + +func initiateHandshake( + t *testing.T, + initCS *testCertState, initVerifier CertVerifier, + respCS *testCertState, respVerifier CertVerifier, +) (initM, respM *Machine, respResult *Result, resp []byte, err error) { + t.Helper() + initM = newTestMachine(t, initCS, initVerifier, true, 100) + msg1, merr := 
initM.Initiate(nil) + require.NoError(t, merr) + + respM = newTestMachine(t, respCS, respVerifier, false, 200) + resp, respResult, err = respM.ProcessPacket(nil, msg1) + return +} + +func doFullHandshake( + t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool, +) (initResult, respResult *Result) { + t.Helper() + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 1000) + respM := newTestMachine(t, respCS, v, false, 2000) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp, respResult, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, respResult) + require.NotEmpty(t, resp) + + _, initResult, err = initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initResult) + + return initResult, respResult +} diff --git a/handshake/machine.go b/handshake/machine.go new file mode 100644 index 00000000..25ed3a5a --- /dev/null +++ b/handshake/machine.go @@ -0,0 +1,444 @@ +package handshake + +import ( + "bytes" + "fmt" + "slices" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/header" +) + +// IndexAllocator is called by the Machine to allocate a local index for the +// handshake. It is called at most once, when the first outgoing message that +// carries a payload is built. +// +// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning +// "no index assigned" on the wire and in the payload-presence checks. If an +// allocator ever returned 0, a legitimate handshake's payload could be +// indistinguishable from an empty one and would be rejected. +type IndexAllocator func() (uint32, error) + +// CertVerifier is called by the Machine after reconstructing the peer's +// certificate from the handshake. The verifier performs all validation +// (CA trust, expiry, policy checks, allow lists). 
+type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error) + +// Result contains the results of a successful handshake. +// Returned by ProcessPacket when the handshake is complete. +type Result struct { + EKey *noise.CipherState + DKey *noise.CipherState + MyCert cert.Certificate + RemoteCert *cert.CachedCertificate + RemoteIndex uint32 + LocalIndex uint32 + HandshakeTime uint64 + MessageIndex uint64 // number of messages exchanged during the handshake + Initiator bool +} + +// Machine drives a Noise handshake through N messages. It handles Noise +// protocol operations, certificate reconstruction, and payload encoding. +// Certificate validation is delegated to the caller via CertVerifier. +// +// A Machine is not safe for concurrent use. The caller must ensure that +// Initiate and ProcessPacket are not called concurrently. +// +// Error contract: when ProcessPacket or Initiate returns an error, callers +// must check Failed() to decide what to do next. If Failed() is false the +// underlying noise state was not advanced (the packet was rejected before +// ReadMessage took effect, or the rejection is non-fatal like a stale +// retransmit) and the Machine can accept another packet. If Failed() is +// true the Machine is unrecoverable and the caller must abandon it. +type Machine struct { + hs *noise.HandshakeState + getCred GetCredentialFunc + allocIndex IndexAllocator + verifier CertVerifier + result *Result + msgs []msgFlags + myVersion cert.Version + subtype header.MessageSubType + indexAllocated bool + remoteCertSet bool + payloadSet bool + failed bool +} + +// NewMachine creates a handshake state machine. The subtype determines both +// the noise pattern and the per-message content layout. The credential for +// `version` is fetched via getCred and used to seed the noise.HandshakeState. +// IndexAllocator is called lazily when the first outgoing payload is built. 
+func NewMachine( + version cert.Version, + getCred GetCredentialFunc, + verifier CertVerifier, + allocIndex IndexAllocator, + initiator bool, + subtype header.MessageSubType, +) (*Machine, error) { + info, err := subtypeInfoFor(subtype) + if err != nil { + return nil, err + } + + cred := getCred(version) + if cred == nil { + return nil, fmt.Errorf("%w: %v", ErrNoCredential, version) + } + + hs, err := cred.buildHandshakeState(initiator, info.pattern) + if err != nil { + return nil, fmt.Errorf("build noise state: %w", err) + } + + return &Machine{ + hs: hs, + subtype: subtype, + msgs: info.msgs, + getCred: getCred, + allocIndex: allocIndex, + verifier: verifier, + myVersion: version, + result: &Result{ + Initiator: initiator, + }, + }, nil +} + +// Failed returns true if the Machine is in an unrecoverable state. +func (m *Machine) Failed() bool { + return m.failed +} + +// Subtype returns the handshake subtype this Machine was built for. +func (m *Machine) Subtype() header.MessageSubType { + return m.subtype +} + +// MessageIndex returns the noise handshake message index, which equals the +// wire counter of the most recently sent or received message. +func (m *Machine) MessageIndex() int { + return m.hs.MessageIndex() +} + +// requireComplete checks that both a peer cert and payload have been received. +// Marks the machine as failed if not. +func (m *Machine) requireComplete() error { + if !m.payloadSet || !m.remoteCertSet { + m.failed = true + return ErrIncompleteHandshake + } + return nil +} + +// myMsgFlags returns the flags for the current outgoing message. +func (m *Machine) myMsgFlags() msgFlags { + idx := m.hs.MessageIndex() + if idx < len(m.msgs) { + return m.msgs[idx] + } + return msgFlags{} +} + +// peerMsgFlags returns the flags for the message we just read. 
+func (m *Machine) peerMsgFlags() msgFlags { + idx := m.hs.MessageIndex() - 1 + if idx >= 0 && idx < len(m.msgs) { + return m.msgs[idx] + } + return msgFlags{} +} + +// Initiate produces the first handshake message. Only valid for initiators, +// and must be called exactly once before ProcessPacket. +// +// out is a destination buffer the message is appended to and returned. Pass +// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g. +// buf[:0]) with sufficient capacity to avoid allocation. +// +// An error return may not indicate a fatal condition, check Failed() to +// determine if the Machine can still be used. +func (m *Machine) Initiate(out []byte) ([]byte, error) { + if m.failed { + return nil, ErrMachineFailed + } + if !m.result.Initiator { + m.failed = true + return nil, ErrInitiateOnResponder + } + if m.hs.MessageIndex() != 0 { + m.failed = true + return nil, ErrInitiateAlreadyCalled + } + + // At MessageIndex=0 with RemoteIndex still zero, buildResponse produces + // header counter 1 and remote index 0, which is what the initial message needs. + out, _, _, err := m.buildResponse(out) + if err != nil { + m.failed = true + return nil, err + } + return out, nil +} + +// ProcessPacket handles an incoming handshake message. It advances the Noise +// state, validates the peer certificate via the verifier, and optionally +// produces a response. +// +// out is a destination buffer the response is appended to and returned. Pass +// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g. +// buf[:0]) with sufficient capacity to avoid allocation. The returned slice +// is nil when no outgoing message is produced (handshake complete on this +// side, or final message of a multi-message pattern). +// +// Returns a non-nil Result when the handshake is complete. +// An error return may not indicate a fatal condition, check Failed() to +// determine if the Machine can still be used. 
+func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) { + if m.failed { + return nil, nil, ErrMachineFailed + } + if len(packet) < header.Len { + return nil, nil, ErrPacketTooShort + } + // Reject packets whose subtype doesn't match the one this Machine was + // built for. A pending handshake that suddenly receives a different + // subtype on its index is either a stray packet that matched by chance + // or a peer protocol violation; drop it without failing the Machine so + // the legitimate retransmit can still complete. + if header.MessageSubType(packet[1]) != m.subtype { + return nil, nil, ErrSubtypeMismatch + } + if m.result.Initiator && m.hs.MessageIndex() == 0 { + m.failed = true + return nil, nil, ErrInitiateNotCalled + } + + // The (eKey, dKey) ordering here is correct for IX, where the initiator + // completes the handshake by reading the responder's stage-2 message. + // noise returns (cs1, cs2) where cs1 is the initiator->responder cipher. + // For 3-message patterns where a responder finishes by reading the final + // message, this ordering would be wrong; revisit when XX/pqIX lands. + msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:]) + if err != nil { + // Noise ReadMessage failed. The noise library checkpoints and rolls back + // on failure, so the Machine is still alive. The caller can retry with + // a different packet. + return nil, nil, fmt.Errorf("noise ReadMessage: %w", err) + } + + // From here on, noise state has advanced. Any error is fatal. + flags := m.peerMsgFlags() + + if err := m.processPayload(msg, flags); err != nil { + return nil, nil, err + } + + // If ReadMessage derived keys, the handshake is complete. Noise should + // always produce both keys together; asymmetry is a protocol invariant + // violation. 
+ if eKey != nil || dKey != nil { + if eKey == nil || dKey == nil { + m.failed = true + return nil, nil, ErrAsymmetricCipherKeys + } + if err := m.requireComplete(); err != nil { + return nil, nil, err + } + return nil, m.completed(eKey, dKey), nil + } + + // ReadMessage didn't complete, produce the next outgoing message + out, dk, ek, err := m.buildResponse(out) + if err != nil { + m.failed = true + return nil, nil, err + } + + if ek != nil || dk != nil { + if ek == nil || dk == nil { + m.failed = true + return nil, nil, ErrAsymmetricCipherKeys + } + if err := m.requireComplete(); err != nil { + return nil, nil, err + } + return out, m.completed(ek, dk), nil + } + + return out, nil, nil +} + +func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result { + m.result.EKey = eKey + m.result.DKey = dKey + m.result.MessageIndex = uint64(m.hs.MessageIndex()) + return m.result +} + +func (m *Machine) processPayload(msg []byte, flags msgFlags) error { + if len(msg) == 0 { + if flags.expectsPayload || flags.expectsCert { + m.failed = true + return ErrMissingContent + } + return nil + } + + payload, err := UnmarshalPayload(msg) + if err != nil { + m.failed = true + return fmt.Errorf("unmarshal handshake: %w", err) + } + + // Assert the payload contains exactly what we expect + hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 + if hasPayloadData != flags.expectsPayload { + m.failed = true + return ErrUnexpectedContent + } + + hasCertData := len(payload.Cert) > 0 + if hasCertData != flags.expectsCert { + m.failed = true + return ErrUnexpectedContent + } + + // Process payload + if flags.expectsPayload { + if m.result.Initiator { + m.result.RemoteIndex = payload.ResponderIndex + } else { + m.result.RemoteIndex = payload.InitiatorIndex + } + m.result.HandshakeTime = payload.Time + m.payloadSet = true + } + + // Process certificate + if flags.expectsCert { + if err := m.validateCert(payload); err != nil { + return err + } + 
} + + return nil +} + +func (m *Machine) validateCert(payload Payload) error { + cred := m.getCred(m.myVersion) + if cred == nil { + m.failed = true + return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion) + } + rc, err := cert.Recombine( + cert.Version(payload.CertVersion), + payload.Cert, + m.hs.PeerStatic(), + cred.Cert.Curve(), + ) + if err != nil { + m.failed = true + return fmt.Errorf("recombine cert: %w", err) + } + + if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) { + m.failed = true + return ErrPublicKeyMismatch + } + + // Version negotiation, if the peer sent a different version and we have it, switch + if rc.Version() != m.myVersion { + if m.getCred(rc.Version()) != nil { + m.myVersion = rc.Version() + } + } + + verified, err := m.verifier(rc) + if err != nil { + m.failed = true + return fmt.Errorf("verify cert: %w", err) + } + + m.result.RemoteCert = verified + m.remoteCertSet = true + return nil +} + +func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) { + if !flags.expectsPayload && !flags.expectsCert { + return nil, nil + } + + var p Payload + if flags.expectsPayload { + if !m.indexAllocated { + index, err := m.allocIndex() + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err) + } + m.result.LocalIndex = index + m.indexAllocated = true + } + + if m.result.Initiator { + p.InitiatorIndex = m.result.LocalIndex + } else { + p.ResponderIndex = m.result.LocalIndex + p.InitiatorIndex = m.result.RemoteIndex + } + p.Time = uint64(time.Now().UnixNano()) + } + if flags.expectsCert { + cred := m.getCred(m.myVersion) + if cred == nil { + return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion) + } + p.Cert = cred.Bytes + p.CertVersion = uint32(cred.Cert.Version()) + m.result.MyCert = cred.Cert + } + + return MarshalPayload(nil, p), nil +} + +func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) { + flags := m.myMsgFlags() + hsBytes, err := m.marshalOutgoing(flags) + if 
err != nil { + return nil, nil, nil, err + } + + // Extend out by header.Len to make room for the header. slices.Grow is a + // no-op when the cap is already sufficient (the zero-copy case where the + // caller passed a pre-sized buffer). header.Encode overwrites the new + // bytes, so they don't need to be zeroed. + start := len(out) + out = slices.Grow(out, header.Len)[:start+header.Len] + header.Encode( + out[start:], + header.Version, header.Handshake, m.subtype, + m.result.RemoteIndex, + uint64(m.hs.MessageIndex()+1), + ) + + // noise.WriteMessage appends the encrypted handshake message to out, + // reusing capacity when present. + // + // The (dKey, eKey) ordering here is correct for IX, where the responder + // completes the handshake by writing the stage-2 message. noise returns + // (cs1, cs2) where cs1 is the initiator->responder cipher (which is the + // responder's decrypt key). For 3-message patterns where an initiator + // finishes by writing the final message, this ordering would be wrong; + // revisit when XX/pqIX lands. 
+ out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err) + } + + return out, dKey, eKey, nil +} diff --git a/handshake/machine_test.go b/handshake/machine_test.go new file mode 100644 index 00000000..722a39e1 --- /dev/null +++ b/handshake/machine_test.go @@ -0,0 +1,662 @@ +package handshake + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/noiseutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMachineIXHappyPath(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name()) + assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name()) + + assert.Equal(t, uint32(1000), initR.LocalIndex) + assert.Equal(t, uint32(2000), initR.RemoteIndex) + assert.Equal(t, uint32(2000), respR.LocalIndex) + assert.Equal(t, uint32(1000), respR.RemoteIndex) + + assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages") + assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages") + + ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello")) + require.NoError(t, err) + pt1, err := respR.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, []byte("hello"), pt1) + + ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world")) + require.NoError(t, err) + pt2, err := initR.DKey.Decrypt(nil, nil, 
ct2) + require.NoError(t, err) + assert.Equal(t, []byte("world"), pt2) +} + +func TestMachineInitiateErrors(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("initiate on responder", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, err := m.Initiate(nil) + require.ErrorIs(t, err, ErrInitiateOnResponder) + assert.True(t, m.Failed()) + }) + + t.Run("initiate called twice", func(t *testing.T) { + m := newTestMachine(t, cs, v, true, 100) + _, err := m.Initiate(nil) + require.NoError(t, err) + _, err = m.Initiate(nil) + require.ErrorIs(t, err, ErrInitiateAlreadyCalled) + assert.True(t, m.Failed()) + }) + + t.Run("process packet before initiate on initiator", func(t *testing.T) { + m := newTestMachine(t, cs, v, true, 100) + _, _, err := m.ProcessPacket(nil, make([]byte, 100)) + require.ErrorIs(t, err, ErrInitiateNotCalled) + assert.True(t, m.Failed()) + }) + + t.Run("calling failed machine", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, err := m.Initiate(nil) // fails: responder + require.Error(t, err) + _, err = m.Initiate(nil) // fails: already failed + require.ErrorIs(t, err, ErrMachineFailed) + }) +} + +func TestMachineProcessPacketErrors(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("packet too short", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, _, err := m.ProcessPacket(nil, []byte{1, 2, 3}) + require.ErrorIs(t, err, ErrPacketTooShort) + assert.False(t, m.Failed(), "short 
packet should not kill machine") + }) + + t.Run("noise decryption failure is recoverable", func(t *testing.T) { + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + initM := newTestMachine(t, initCS, v, true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + respM := newTestMachine(t, cs, v, false, 200) + resp, _, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + + corrupted := make([]byte, len(resp)) + copy(corrupted, resp) + for i := header.Len; i < len(corrupted); i++ { + corrupted[i] ^= 0xff + } + _, _, err = initM.ProcessPacket(nil, corrupted) + require.Error(t, err) + assert.False(t, initM.Failed(), "noise failure should be recoverable") + + // And the machine should still complete a real handshake afterward. + _, result, err := initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, result, "initiator should complete on the legitimate response") + }) + + t.Run("invalid cert is fatal", func(t *testing.T) { + otherCA, _, otherCAKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + respM := newTestMachine(t, cs, v, false, 200) + _, _, err = respM.ProcessPacket(nil, msg1) + require.Error(t, err) + assert.True(t, respM.Failed(), "cert validation failure should kill machine") + }) + + t.Run("subtype mismatch is recoverable", func(t *testing.T) { + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + initM := newTestMachine(t, initCS, v, true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + // Mutate the subtype byte (offset 1 in the header) to a value the 
+ // responder Machine wasn't built for. + bad := make([]byte, len(msg1)) + copy(bad, msg1) + bad[1] = 0xff + + respM := newTestMachine(t, cs, v, false, 200) + _, _, err = respM.ProcessPacket(nil, bad) + require.ErrorIs(t, err, ErrSubtypeMismatch) + assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine") + + // And the machine should still complete a real handshake afterward. + resp, result, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet") + assert.NotEmpty(t, resp, "responder should produce a stage-2 reply") + }) +} + +// TestMachineProcessPayload exercises processPayload's internal validation +// directly. Most of these failure modes can't be reached black-box once the +// subtype check at the top of ProcessPacket gates external callers, so we +// drive them by hand here for coverage. +func TestMachineProcessPayload(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("empty message with expects fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true}) + require.ErrorIs(t, err, ErrMissingContent) + assert.True(t, m.Failed()) + }) + + t.Run("empty message with no expects passes", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload(nil, msgFlags{}) + require.NoError(t, err) + assert.False(t, m.Failed()) + }) + + t.Run("malformed protobuf is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true}) + require.Error(t, err) + 
assert.True(t, m.Failed()) + }) + + t.Run("unexpected payload data is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // A payload with index data when none was expected. + bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}) + err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) + + t.Run("unexpected cert data is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // A payload with cert when none was expected. + bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}) + err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) + + t.Run("missing payload data when expected is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // Cert present, but no index/time fields. + bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}) + err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) +} + +// TestMachineRequireComplete checks the fail-on-incomplete-handshake path +// directly. Like processPayload above this isn't reachable from a normal IX +// flow, so we drive it by hand. 
+func TestMachineRequireComplete(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("missing both fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("payload only fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.payloadSet = true + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("cert only fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.remoteCertSet = true + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("both set passes", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.payloadSet = true + m.remoteCertSet = true + err := m.requireComplete() + require.NoError(t, err) + assert.False(t, m.Failed()) + }) +} + +func TestMachineAESCipher(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + initCS := newTestCertStateWithCipher( + t, ca, caKey, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + noiseutil.CipherAESGCM, + ) + respCS := newTestCertStateWithCipher( + t, ca, caKey, "resp", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + noiseutil.CipherAESGCM, + ) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works")) + require.NoError(t, err) + pt1, err := respR.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, 
[]byte("works"), pt1) + + ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back")) + require.NoError(t, err) + pt2, err := initR.DKey.Decrypt(nil, nil, ct2) + require.NoError(t, err) + assert.Equal(t, []byte("back"), pt2) +} + +func TestResultFields(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + assert.True(t, initR.Initiator) + assert.False(t, respR.Initiator) + assert.NotZero(t, initR.HandshakeTime) + assert.NotZero(t, respR.HandshakeTime) + assert.NotNil(t, initR.RemoteCert) + assert.NotNil(t, respR.RemoteCert) +} + +func TestMachineBufferReuse(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 1000) + respM := newTestMachine(t, respCS, v, false, 2000) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + t.Run("response writes into provided buffer", func(t *testing.T) { + buf := make([]byte, 0, 4096) + resp, result, err := respM.ProcessPacket(buf, msg1) + require.NoError(t, err) + require.NotNil(t, result) + + assert.NotEmpty(t, resp, "response should have content") + assert.Equal(t, &buf[:1][0], &resp[:1][0], + "response should reuse the provided buffer's backing array") + }) + + t.Run("initiate writes into provided buffer", func(t *testing.T) { + 
initM2 := newTestMachine(t, initCS, v, true, 3000) + buf := make([]byte, 0, 4096) + msg, err := initM2.Initiate(buf) + require.NoError(t, err) + + assert.NotEmpty(t, msg, "initiate should have content") + assert.Equal(t, &buf[:1][0], &msg[:1][0], + "initiate should reuse the provided buffer's backing array") + }) + + t.Run("nil out still works", func(t *testing.T) { + initM2 := newTestMachine(t, initCS, v, true, 4000) + respM2 := newTestMachine(t, respCS, v, false, 5000) + + msg1, err := initM2.Initiate(nil) + require.NoError(t, err) + + resp, _, err := respM2.ProcessPacket(nil, msg1) + require.NoError(t, err) + + out, result, err := initM2.ProcessPacket(nil, resp) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Nil(t, out, "initiator should have no response for IX msg2") + }) +} + +func TestMachineMsgIndexTracking(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 100) + respM := newTestMachine(t, respCS, v, false, 200) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp1, result1, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + assert.NotNil(t, result1) + + _, result2, err := initM.ProcessPacket(nil, resp1) + require.NoError(t, err) + assert.NotNil(t, result2) +} + +func TestMachineThreeMessagePattern(t *testing.T) { + registerTestXXInfo(t) + + // Use HandshakeXX (3 messages) to verify the Machine handles multi-message + // patterns correctly. 
XX flow: + // msg1 (I->R): [E] - payload only, no cert + // msg2 (R->I): [E, ee, S, es] - payload + cert + // msg3 (I->R): [S, se] - cert only (no payload, not first two) + + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + v := testVerifier(caPool) + + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM, err := NewMachine( + cert.Version2, + initCS.getCredential, v, + func() (uint32, error) { return 1000, nil }, + true, header.HandshakeXXPSK0, + ) + require.NoError(t, err) + + respM, err := NewMachine( + cert.Version2, + respCS.getCredential, v, + func() (uint32, error) { return 2000, nil }, + false, header.HandshakeXXPSK0, + ) + require.NoError(t, err) + + // msg1: initiator -> responder (E only, no cert) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + assert.NotEmpty(t, msg1) + + // Responder processes msg1, should not complete yet, should produce msg2 + msg2, result, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + assert.Nil(t, result, "XX should not complete on msg1") + assert.NotEmpty(t, msg2, "responder should produce msg2") + + // Initiator processes msg2: gets responder's cert, produces msg3, and + // completes (WriteMessage for msg3 derives keys) + msg3, initResult, err := initM.ProcessPacket(nil, msg2) + require.NoError(t, err) + require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3") + assert.NotEmpty(t, msg3, "initiator should produce msg3") + assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name()) + + // Responder processes msg3: gets initiator's cert and completes + _, respResult, err := respM.ProcessPacket(nil, msg3) + require.NoError(t, err) + require.NotNil(t, respResult, "XX responder should 
complete on msg3") + assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name()) + + assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages") + assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages") + + // Verify keys work + ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages")) + require.NoError(t, err) + pt1, err := respResult.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, []byte("three messages"), pt1) +} + +// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with +// IX since the cert is always in the payload. A 3-message pattern test (HybridIX) +// should exercise the case where cert arrives in msg3 and verify that completing +// without it fails. + +func TestMachineExpiredCert(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, + time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour), + nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + expCert, _, expKeyPEM, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + "expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour), + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil, + ) + expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM) + require.NoError(t, err) + expHsBytes, err := expCert.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + + expiredCS := &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs), + }, + } + + respCS := newTestCertState( + t, ca, caKey, "responder", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, expiredCS, testVerifier(caPool), + respCS, testVerifier(caPool), + ) + require.ErrorContains(t, err, "verify cert") + assert.True(t, 
respM.Failed()) +} + +func TestMachineNoCertNetworks(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + caHsBytes, err := ca.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + + noNetCS := &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs), + }, + } + + respCS := newTestCertState( + t, ca, caKey, "responder", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, noNetCS, testVerifier(caPool), + respCS, testVerifier(caPool), + ) + require.Error(t, err) + assert.True(t, respM.Failed()) +} + +func TestMachineDifferentCAs(t *testing.T) { + ca1, _, caKey1, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + ca2, _, caKey2, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + + initCS := newTestCertState( + t, ca1, caKey1, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + respCS := newTestCertState( + t, ca2, caKey2, "resp", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, initCS, testVerifier(ct.NewTestCAPool(ca1)), + respCS, testVerifier(ct.NewTestCAPool(ca2)), + ) + require.ErrorContains(t, err, "verify cert") + assert.True(t, respM.Failed()) +} + +func TestMachineVersionNegotiation(t *testing.T) { + ca1, _, caKey1, _ := ct.NewTestCaCert( + cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + ca2, _, caKey2, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca1, ca2) + + makeMultiVersionResp := func(t 
*testing.T) *testCertState { + t.Helper() + respCertV1, _, respKeyPEM, _ := ct.NewTestCert( + cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp", + ca1.NotBefore(), ca1.NotAfter(), + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil, + ) + respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM) + respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2) + respHsV1, _ := respCertV1.MarshalForHandshakes() + respHsV2, _ := respCertV2.MarshalForHandshakes() + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + return &testCertState{ + version: cert.Version1, + creds: map[cert.Version]*Credential{ + cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs), + cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs), + }, + } + } + + t.Run("responder matches initiator version", func(t *testing.T) { + initCS := newTestCertState( + t, ca2, caKey2, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + respCS := makeMultiVersionResp(t) + v := testVerifier(caPool) + + initM, _, respResult, resp, err := initiateHandshake( + t, initCS, v, + respCS, v, + ) + require.NoError(t, err) + require.NotNil(t, respResult) + + assert.Equal(t, cert.Version2, respResult.MyCert.Version(), + "responder should negotiate to initiator's version") + + _, initResult, err := initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initResult) + assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(), + "initiator should see V2 cert from responder") + }) + + t.Run("responder keeps version when no match available", func(t *testing.T) { + initCS := newTestCertState( + t, ca2, caKey2, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + + respCert, _, respKeyPEM, _ := ct.NewTestCert( + cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp", + ca1.NotBefore(), ca1.NotAfter(), + 
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil, + ) + respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM) + respHs, _ := respCert.MarshalForHandshakes() + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + respCS := &testCertState{ + version: cert.Version1, + creds: map[cert.Version]*Credential{ + cert.Version1: NewCredential(respCert, respHs, respKey, ncs), + }, + } + + v := testVerifier(caPool) + _, _, respResult, _, err := initiateHandshake( + t, initCS, v, + respCS, v, + ) + require.NoError(t, err) + require.NotNil(t, respResult) + + assert.Equal(t, cert.Version1, respResult.MyCert.Version(), + "responder should keep V1 when V2 not available") + }) +} diff --git a/handshake/patterns.go b/handshake/patterns.go new file mode 100644 index 00000000..a0cc1a70 --- /dev/null +++ b/handshake/patterns.go @@ -0,0 +1,54 @@ +package handshake + +import ( + "fmt" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/header" +) + +// msgFlags tracks what application data a handshake message carries. +type msgFlags struct { + expectsPayload bool // message carries indexes and time + expectsCert bool // message carries the certificate +} + +// subtypeInfo bundles the noise pattern with the per-message flags for a +// given handshake subtype. +type subtypeInfo struct { + pattern noise.HandshakePattern + msgs []msgFlags +} + +// subtypeInfos defines the noise pattern and message content layout for each +// handshake subtype. 
+var subtypeInfos = map[header.MessageSubType]subtypeInfo{ + // IX: 2 messages, both carry payload and cert + header.HandshakeIXPSK0: { + pattern: noise.HandshakeIX, + msgs: []msgFlags{ + {expectsPayload: true, expectsCert: true}, + {expectsPayload: true, expectsCert: true}, + }, + }, + + // XX: 3 messages + // msg1 (I->R): payload only + // msg2 (R->I): payload + cert + // msg3 (I->R): cert only + //header.HandshakeXXPSK0: { + // pattern: noise.HandshakeXX, + // msgs: []msgFlags{ + // {expectsPayload: true, expectsCert: false}, + // {expectsPayload: true, expectsCert: true}, + // {expectsPayload: false, expectsCert: true}, + // }, + //}, +} + +func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) { + if info, ok := subtypeInfos[subtype]; ok { + return info, nil + } + return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype) +} diff --git a/handshake/patterns_test.go b/handshake/patterns_test.go new file mode 100644 index 00000000..d6207e00 --- /dev/null +++ b/handshake/patterns_test.go @@ -0,0 +1,63 @@ +package handshake + +import ( + "testing" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSubtypeInfo(t *testing.T) { + t.Run("IX", func(t *testing.T) { + info, err := subtypeInfoFor(header.HandshakeIXPSK0) + require.NoError(t, err) + assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name) + require.Len(t, info.msgs, 2) + // msg1: payload + cert + assert.True(t, info.msgs[0].expectsPayload) + assert.True(t, info.msgs[0].expectsCert) + // msg2: payload + cert + assert.True(t, info.msgs[1].expectsPayload) + assert.True(t, info.msgs[1].expectsCert) + }) + + t.Run("XX", func(t *testing.T) { + registerTestXXInfo(t) + info, err := subtypeInfoFor(header.HandshakeXXPSK0) + require.NoError(t, err) + assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name) + require.Len(t, info.msgs, 3) + // msg1: payload only + assert.True(t, 
info.msgs[0].expectsPayload) + assert.False(t, info.msgs[0].expectsCert) + // msg2: payload + cert + assert.True(t, info.msgs[1].expectsPayload) + assert.True(t, info.msgs[1].expectsCert) + // msg3: cert only + assert.False(t, info.msgs[2].expectsPayload) + assert.True(t, info.msgs[2].expectsCert) + }) + + t.Run("unknown subtype returns error", func(t *testing.T) { + _, err := subtypeInfoFor(99) + require.ErrorIs(t, err, ErrUnknownSubtype) + }) +} + +// registerTestXXInfo temporarily registers XX subtype info for testing. +func registerTestXXInfo(t *testing.T) { + t.Helper() + subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{ + pattern: noise.HandshakeXX, + msgs: []msgFlags{ + {expectsPayload: true, expectsCert: false}, + {expectsPayload: true, expectsCert: true}, + {expectsPayload: false, expectsCert: true}, + }, + } + t.Cleanup(func() { + delete(subtypeInfos, header.HandshakeXXPSK0) + }) +} diff --git a/handshake/payload.go b/handshake/payload.go new file mode 100644 index 00000000..4567fc0d --- /dev/null +++ b/handshake/payload.go @@ -0,0 +1,173 @@ +package handshake + +import ( + "errors" + "math" + + "google.golang.org/protobuf/encoding/protowire" +) + +var ( + errInvalidHandshakeMessage = errors.New("invalid handshake message") + errInvalidHandshakeDetails = errors.New("invalid handshake details") +) + +// Payload represents the decoded fields of a handshake message. +// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}. +type Payload struct { + Cert []byte + InitiatorIndex uint32 + ResponderIndex uint32 + Time uint64 + CertVersion uint32 +} + +// Proto field numbers for NebulaHandshakeDetails +const ( + fieldCert = 1 // bytes + fieldInitiatorIndex = 2 // uint32 + fieldResponderIndex = 3 // uint32 + fieldTime = 5 // uint64 + fieldCertVersion = 8 // uint32 +) + +// MarshalPayload encodes a handshake payload in protobuf wire format compatible +// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}. 
+// Returns out (which may be nil), with the marshalled Payload appended to it. +func MarshalPayload(out []byte, p Payload) []byte { + var details []byte + + if len(p.Cert) > 0 { + details = protowire.AppendTag(details, fieldCert, protowire.BytesType) + details = protowire.AppendBytes(details, p.Cert) + } + if p.InitiatorIndex != 0 { + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.InitiatorIndex)) + } + if p.ResponderIndex != 0 { + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.ResponderIndex)) + } + if p.Time != 0 { + details = protowire.AppendTag(details, fieldTime, protowire.VarintType) + details = protowire.AppendVarint(details, p.Time) + } + if p.CertVersion != 0 { + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.CertVersion)) + } + + out = protowire.AppendTag(out, 1, protowire.BytesType) + out = protowire.AppendBytes(out, details) + + return out +} + +// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message. 
+func UnmarshalPayload(b []byte) (Payload, error) { + var p Payload + + for len(b) > 0 { + num, typ, n := protowire.ConsumeTag(b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + + switch { + case num == 1 && typ == protowire.BytesType: + details, n := protowire.ConsumeBytes(b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + if err := unmarshalPayloadDetails(&p, details); err != nil { + return p, err + } + default: + n := protowire.ConsumeFieldValue(num, typ, b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + } + } + + return p, nil +} + +func unmarshalPayloadDetails(p *Payload, b []byte) error { + for len(b) > 0 { + num, typ, n := protowire.ConsumeTag(b) + if n < 0 { + return errInvalidHandshakeDetails + } + b = b[n:] + + // For known field numbers, reject any non-matching wire type as a + // hard error rather than silently skipping. The caller will catch + // missing-field cases downstream, but a wire-type mismatch on a tag + // we know is a peer protocol violation worth flagging here. + // Repeated occurrences of a singular field follow proto3 last-wins. + switch num { + case fieldCert: + if typ != protowire.BytesType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeBytes(b) + if n < 0 { + return errInvalidHandshakeDetails + } + p.Cert = append([]byte(nil), v...) 
+ b = b[n:] + case fieldInitiatorIndex: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.InitiatorIndex = uint32(v) + b = b[n:] + case fieldResponderIndex: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.ResponderIndex = uint32(v) + b = b[n:] + case fieldTime: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 { + return errInvalidHandshakeDetails + } + p.Time = v + b = b[n:] + case fieldCertVersion: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.CertVersion = uint32(v) + b = b[n:] + default: + n := protowire.ConsumeFieldValue(num, typ, b) + if n < 0 { + return errInvalidHandshakeDetails + } + b = b[n:] + } + } + return nil +} diff --git a/handshake/payload_test.go b/handshake/payload_test.go new file mode 100644 index 00000000..2ff3231c --- /dev/null +++ b/handshake/payload_test.go @@ -0,0 +1,361 @@ +package handshake + +import ( + "bytes" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" +) + +func TestPayloadRoundTrip(t *testing.T) { + t.Run("all fields set", func(t *testing.T) { + data := MarshalPayload(nil, Payload{ + Cert: []byte("test-cert-bytes"), + CertVersion: 2, + InitiatorIndex: 12345, + ResponderIndex: 67890, + Time: 1234567890, + }) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, []byte("test-cert-bytes"), got.Cert) + assert.Equal(t, uint32(12345), got.InitiatorIndex) + assert.Equal(t, uint32(67890), got.ResponderIndex) + 
assert.Equal(t, uint64(1234567890), got.Time) + assert.Equal(t, uint32(2), got.CertVersion) + }) + + t.Run("minimal fields", func(t *testing.T) { + data := MarshalPayload(nil, Payload{InitiatorIndex: 1}) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, uint32(1), got.InitiatorIndex) + assert.Equal(t, uint32(0), got.ResponderIndex) + assert.Equal(t, uint64(0), got.Time) + assert.Nil(t, got.Cert) + }) + + t.Run("empty payload", func(t *testing.T) { + data := MarshalPayload(nil, Payload{}) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, uint32(0), got.InitiatorIndex) + }) + + t.Run("large cert bytes", func(t *testing.T) { + bigCert := make([]byte, 4096) + for i := range bigCert { + bigCert[i] = byte(i % 256) + } + + data := MarshalPayload(nil, Payload{ + Cert: bigCert, + CertVersion: 2, + InitiatorIndex: 999, + }) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, bigCert, got.Cert) + assert.Equal(t, uint32(999), got.InitiatorIndex) + }) + + t.Run("append to existing buffer", func(t *testing.T) { + prefix := []byte("prefix") + data := MarshalPayload(prefix, Payload{InitiatorIndex: 42}) + + assert.Equal(t, []byte("prefix"), data[:6]) + + got, err := UnmarshalPayload(data[6:]) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) +} + +func TestPayloadUnknownFields(t *testing.T) { + t.Run("unknown field in outer message is skipped", func(t *testing.T) { + // Marshal a normal payload then append an unknown field (field 99, varint) + data := MarshalPayload(nil, Payload{InitiatorIndex: 42}) + data = protowire.AppendTag(data, 99, protowire.VarintType) + data = protowire.AppendVarint(data, 12345) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) + + t.Run("unknown field in details is skipped", func(t *testing.T) { + // Build details with a known field + unknown field + var 
details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 77) + // Unknown field 50, varint + details = protowire.AppendTag(details, 50, protowire.VarintType) + details = protowire.AppendVarint(details, 9999) + // Another known field after the unknown one + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 88) + + // Wrap in outer message + var data []byte + data = protowire.AppendTag(data, 1, protowire.BytesType) + data = protowire.AppendBytes(data, details) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(77), got.InitiatorIndex) + assert.Equal(t, uint32(88), got.ResponderIndex) + }) + + t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) { + // Fields 6 and 7 are reserved in the proto definition + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 100) + details = protowire.AppendTag(details, 6, protowire.VarintType) + details = protowire.AppendVarint(details, 1) + details = protowire.AppendTag(details, 7, protowire.VarintType) + details = protowire.AppendVarint(details, 2) + + var data []byte + data = protowire.AppendTag(data, 1, protowire.BytesType) + data = protowire.AppendBytes(data, details) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(100), got.InitiatorIndex) + }) +} + +func TestPayloadBytesConsumed(t *testing.T) { + t.Run("all bytes consumed on valid input", func(t *testing.T) { + original := Payload{ + Cert: []byte("cert"), + CertVersion: 2, + InitiatorIndex: 100, + ResponderIndex: 200, + Time: 999, + } + data := MarshalPayload(nil, original) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + // Re-marshal and compare — proves we consumed and reproduced all fields + remarshaled := 
MarshalPayload(nil, got) + assert.Equal(t, data, remarshaled) + }) +} + +// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope +// so UnmarshalPayload can reach unmarshalPayloadDetails. +func wrapDetails(details []byte) []byte { + var out []byte + out = protowire.AppendTag(out, 1, protowire.BytesType) + out = protowire.AppendBytes(out, details) + return out +} + +func TestPayloadUnmarshalErrors(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + got, err := UnmarshalPayload(nil) + require.NoError(t, err) + assert.Equal(t, uint32(0), got.InitiatorIndex) + }) + + t.Run("truncated outer tag", func(t *testing.T) { + _, err := UnmarshalPayload([]byte{0x80}) + assert.Error(t, err) + }) + + t.Run("truncated outer details field", func(t *testing.T) { + _, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05}) + assert.Error(t, err) + }) + + t.Run("truncated outer unknown field", func(t *testing.T) { + // Valid tag for unknown field 99 varint, but no value follows + var data []byte + data = protowire.AppendTag(data, 99, protowire.VarintType) + _, err := UnmarshalPayload(data) + assert.Error(t, err) + }) + + t.Run("truncated details tag", func(t *testing.T) { + _, err := UnmarshalPayload(wrapDetails([]byte{0x80})) + assert.Error(t, err) + }) + + t.Run("truncated cert bytes", func(t *testing.T) { + // Field 1 (cert), bytes type, length 10 but only 2 bytes + var details []byte + details = protowire.AppendTag(details, fieldCert, protowire.BytesType) + details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated initiator index varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = append(details, 0x80) // incomplete varint + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated responder 
index varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated time varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldTime, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated cert version varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated unknown field in details", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, 50, protowire.VarintType) + details = append(details, 0x80) // incomplete varint + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert with wrong wire type rejected", func(t *testing.T) { + // fieldCert as Varint instead of Bytes. + var details []byte + details = protowire.AppendTag(details, fieldCert, protowire.VarintType) + details = protowire.AppendVarint(details, 42) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("initiator index with wrong wire type rejected", func(t *testing.T) { + // fieldInitiatorIndex as Bytes instead of Varint. 
+ var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("time with wrong wire type rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldTime, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert version with wrong wire type rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) { + // Per proto3, multiple instances of a singular field are accepted and + // the last value wins. We keep this behavior so that peers using + // alternative encoders aren't rejected. 
+ var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 1) + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 42) + got, err := UnmarshalPayload(wrapDetails(details)) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) + + t.Run("initiator index varint overflow rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, math.MaxUint32+1) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert version varint overflow rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = protowire.AppendVarint(details, math.MaxUint32+1) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + +} + +// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it +// never panics, and for any input that parses cleanly, that re-marshal + +// re-parse is a fix-point. Inputs come from an authenticated peer (post- +// noise-decrypt), so the threat model is "valid peer behaving arbitrarily," +// not "unauthenticated injection." +func FuzzPayload(f *testing.F) { + // Seed corpus with a handful of known-good shapes. 
+ f.Add(MarshalPayload(nil, Payload{})) + f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})) + f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})) + f.Add(MarshalPayload(nil, Payload{ + Cert: []byte("seed-cert"), + InitiatorIndex: 1, + ResponderIndex: 2, + Time: 3, + CertVersion: 2, + })) + f.Add([]byte{}) + f.Add([]byte{0xff}) + + f.Fuzz(func(t *testing.T, data []byte) { + p1, err := UnmarshalPayload(data) + if err != nil { + return + } + + // For any input that parses, re-marshaling and re-parsing must + // yield an equivalent Payload. This catches dispatch bugs (e.g. + // emitting a field on marshal that we don't accept on parse) and + // any non-idempotent parsing behavior. + b2 := MarshalPayload(nil, p1) + p2, err := UnmarshalPayload(b2) + if err != nil { + t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2) + } + if !payloadsEqual(p1, p2) { + t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2) + } + }) +} + +func payloadsEqual(a, b Payload) bool { + return bytes.Equal(a.Cert, b.Cert) && + a.InitiatorIndex == b.InitiatorIndex && + a.ResponderIndex == b.ResponderIndex && + a.Time == b.Time && + a.CertVersion == b.CertVersion +} diff --git a/handshake_ix.go b/handshake_ix.go deleted file mode 100644 index a086960e..00000000 --- a/handshake_ix.go +++ /dev/null @@ -1,813 +0,0 @@ -package nebula - -import ( - "bytes" - "context" - "log/slog" - "net/netip" - "time" - - "github.com/flynn/noise" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/header" -) - -// NOISE IX Handshakes - -// This function constructs a handshake packet, but does not actually send it -// Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { - err := f.handshakeManager.allocateIndex(hh) - if err != nil { - f.l.Error("Failed to generate index", - "error", err, - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": 
"ix_psk0"}, - ) - return false - } - - cs := f.pki.getCertState() - v := cs.initiatingVersion - if hh.initiatingVersionOverride != cert.VersionPre1 { - v = hh.initiatingVersionOverride - } else if v < cert.Version2 { - // If we're connecting to a v6 address we should encourage use of a V2 cert - for _, a := range hh.hostinfo.vpnAddrs { - if a.Is6() { - v = cert.Version2 - break - } - } - } - - crt := cs.getCertificate(v) - if crt == nil { - f.l.Error("Unable to handshake with host because no certificate is available", - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - "certVersion", v, - ) - return false - } - - crtHs := cs.getHandshakeBytes(v) - if crtHs == nil { - f.l.Error("Unable to handshake with host because no certificate handshake bytes is available", - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - "certVersion", v, - ) - return false - } - - ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX) - if err != nil { - f.l.Error("Failed to create connection state", - "error", err, - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - "certVersion", v, - ) - return false - } - hh.hostinfo.ConnectionState = ci - - hs := &NebulaHandshake{ - Details: &NebulaHandshakeDetails{ - InitiatorIndex: hh.hostinfo.localIndexId, - Time: uint64(time.Now().UnixNano()), - Cert: crtHs, - CertVersion: uint32(v), - }, - } - - hsBytes, err := hs.Marshal() - if err != nil { - f.l.Error("Failed to marshal handshake message", - "error", err, - "vpnAddrs", hh.hostinfo.vpnAddrs, - "certVersion", v, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - ) - return false - } - - h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) - - msg, _, _, err := ci.H.WriteMessage(h, hsBytes) - if err != nil { - f.l.Error("Failed to call noise.WriteMessage", - "error", err, - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", 
m{"stage": 0, "style": "ix_psk0"}, - ) - return false - } - - // We are sending handshake packet 1, so we don't expect to receive - // handshake packet 1 from the responder - ci.window.Update(f.l, 1) - - hh.hostinfo.HandshakePacket[0] = msg - hh.ready = true - return true -} - -func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) { - cs := f.pki.getCertState() - crt := cs.GetDefaultCertificate() - if crt == nil { - f.l.Error("Unable to handshake with host because no certificate is available", - "from", via, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - "certVersion", cs.initiatingVersion, - ) - return - } - - ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX) - if err != nil { - f.l.Error("Failed to create connection state", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - // Mark packet 1 as seen so it doesn't show up as missed - ci.window.Update(f.l, 1) - - msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) - if err != nil { - f.l.Error("Failed to call noise.ReadMessage", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - hs := &NebulaHandshake{} - err = hs.Unmarshal(msg) - if err != nil || hs.Details == nil { - f.l.Error("Failed unmarshal handshake message", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) - if err != nil { - f.l.Info("Handshake did not contain a certificate", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) - if err != nil { - fp, fperr := rc.Fingerprint() - if fperr != nil { - fp = "" - } - - attrs := []slog.Attr{ - slog.Any("error", err), - slog.Any("from", via), - slog.Any("handshake", m{"stage": 1, 
"style": "ix_psk0"}), - slog.Any("certVpnNetworks", rc.Networks()), - slog.String("certFingerprint", fp), - } - if f.l.Enabled(context.Background(), slog.LevelDebug) { - attrs = append(attrs, slog.Any("cert", rc)) - } - - // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that - // callers grow conditionally, which has no pair-form equivalent. - //nolint:sloglint - f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) - return - } - - if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.Info("public key mismatch between certificate and handshake", - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - "cert", remoteCert, - ) - return - } - - if remoteCert.Certificate.Version() != ci.myCert.Version() { - // We started off using the wrong certificate version, lets see if we can match the version that was sent to us - myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) - if myCertOtherVersion == nil { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - f.l.Debug("Might be unable to handshake with host due to missing certificate version", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - "cert", remoteCert, - ) - } - } else { - // Record the certificate we are actually using - ci.myCert = myCertOtherVersion - } - } - - if len(remoteCert.Certificate.Networks()) == 0 { - f.l.Info("No networks in certificate", - "error", err, - "from", via, - "cert", remoteCert, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - certName := remoteCert.Certificate.Name() - certVersion := remoteCert.Certificate.Version() - fingerprint := remoteCert.Fingerprint - issuer := remoteCert.Certificate.Issuer() - vpnNetworks := remoteCert.Certificate.Networks() - - anyVpnAddrsInCommon := false - vpnAddrs := make([]netip.Addr, len(vpnNetworks)) - for i, network := range vpnNetworks { - if 
f.myVpnAddrsTable.Contains(network.Addr()) { - f.l.Error("Refusing to handshake with myself", - "vpnNetworks", vpnNetworks, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - vpnAddrs[i] = network.Addr() - if f.myVpnNetworksTable.Contains(network.Addr()) { - anyVpnAddrsInCommon = true - } - } - - if !via.IsRelayed { - // We only want to apply the remote allow list for direct tunnels here - if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", - "vpnAddrs", vpnAddrs, - "from", via, - ) - } - return - } - } - - myIndex, err := generateIndex(f.l) - if err != nil { - f.l.Error("Failed to generate index", - "error", err, - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - hostinfo := &HostInfo{ - ConnectionState: ci, - localIndexId: myIndex, - remoteIndexId: hs.Details.InitiatorIndex, - vpnAddrs: vpnAddrs, - HandshakePacket: make(map[uint8][]byte, 0), - lastHandshakeTime: hs.Details.Time, - relayState: RelayState{ - relays: nil, - relayForByAddr: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, - }, - } - - msgRxL := f.l.With( - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - - if anyVpnAddrsInCommon { - msgRxL.Info("Handshake message received") - } else { - //todo warn if not lighthouse or relay? 
- msgRxL.Info("Handshake message received, but no vpnNetworks in common.") - } - - hs.Details.ResponderIndex = myIndex - hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) - if hs.Details.Cert == nil { - msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available", - "myCertVersion", ci.myCert.Version(), - ) - return - } - - hs.Details.CertVersion = uint32(ci.myCert.Version()) - // Update the time in case their clock is way off from ours - hs.Details.Time = uint64(time.Now().UnixNano()) - - hsBytes, err := hs.Marshal() - if err != nil { - f.l.Error("Failed to marshal handshake message", - "error", err, - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) - msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) - if err != nil { - f.l.Error("Failed to call noise.WriteMessage", - "error", err, - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } else if dKey == nil || eKey == nil { - f.l.Error("Noise did not arrive at a key", - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:])) - copy(hostinfo.HandshakePacket[0], packet[header.Len:]) - - // Regardless of whether you are the sender or receiver, you should arrive here - // and complete standing up the connection. 
- hostinfo.HandshakePacket[2] = make([]byte, len(msg)) - copy(hostinfo.HandshakePacket[2], msg) - - // We are sending handshake packet 2, so we don't expect to receive - // handshake packet 2 from the initiator. - ci.window.Update(f.l, 2) - - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - - hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) - if !via.IsRelayed { - hostinfo.SetRemote(via.UdpAddr) - } - hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) - - existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) - if err != nil { - switch err { - case ErrAlreadySeen: - // Update remote if preferred - if existing.SetRemoteIfPreferred(f.hostMap, via) { - // Send a test packet to ensure the other side has also switched to - // the preferred remote - f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - } - - msg = existing.HandshakePacket[2] - f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if !via.IsRelayed { - err := f.outside.WriteTo(msg, via.UdpAddr) - if err != nil { - f.l.Error("Failed to send handshake message", - "vpnAddrs", existing.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "cached", true, - "error", err, - ) - } else { - f.l.Info("Handshake message sent", - "vpnAddrs", existing.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "cached", true, - ) - } - return - } else { - if via.relay == nil { - f.l.Error("Handshake send failed: both addr and via.relay are nil.") - return - } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.Info("Handshake message sent", - "vpnAddrs", existing.vpnAddrs, - "relay", via.relayHI.vpnAddrs[0], - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "cached", true, - ) - return - } - case 
ErrExistingHostInfo: - // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.Info("Handshake too old", - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "oldHandshakeTime", existing.lastHandshakeTime, - "newHandshakeTime", hostinfo.lastHandshakeTime, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - - // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - return - case ErrLocalIndexCollision: - // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.Error("Failed to add HostInfo due to localIndex collision", - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - "localIndex", hostinfo.localIndexId, - "collision", existing.vpnAddrs, - ) - return - default: - // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete - // And we forget to update it here - f.l.Error("Failed to add HostInfo to HostMap", - "error", err, - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - } - 
- // Do the send - f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if !via.IsRelayed { - err = f.outside.WriteTo(msg, via.UdpAddr) - log := f.l.With( - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - if err != nil { - log.Error("Failed to send handshake", "error", err) - } else { - log.Info("Handshake message sent") - } - } else { - if via.relay == nil { - f.l.Error("Handshake send failed: both addr and via.relay are nil.") - return - } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - // I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure - // it's correctly marked as working. - via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) - f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.Info("Handshake message sent", - "vpnAddrs", vpnAddrs, - "relay", via.relayHI.vpnAddrs[0], - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - } - - f.connectionManager.AddTrafficWatch(hostinfo) - - hostinfo.remotes.RefreshFromHandshake(vpnAddrs) - - // Don't wait for UpdateWorker - if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { - f.lightHouse.TriggerUpdate() - } - - return -} - -func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { - if hh == nil { - // Nothing here to tear down, got a bogus stage 2 packet - return true - } - - hh.Lock() - defer hh.Unlock() - - hostinfo := hh.hostinfo 
- if !via.IsRelayed { - // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. - if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - ) - } - return false - } - } - - ci := hostinfo.ConnectionState - msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) - if err != nil { - f.l.Error("Failed to call noise.ReadMessage", - "error", err, - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "header", h, - ) - - // We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying - // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the - // near future - return false - } else if dKey == nil || eKey == nil { - f.l.Error("Noise did not arrive at a key", - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - - // This should be impossible in IX but just in case, if we get here then there is no chance to recover - // the handshake state machine. Tear it down - return true - } - - hs := &NebulaHandshake{} - err = hs.Unmarshal(msg) - if err != nil || hs.Details == nil { - f.l.Error("Failed unmarshal handshake message", - "error", err, - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - - // The handshake state machine is complete, if things break now there is no chance to recover. 
Tear down and start again - return true - } - - rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) - if err != nil { - f.l.Info("Handshake did not contain a certificate", - "error", err, - "from", via, - "vpnAddrs", hostinfo.vpnAddrs, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - return true - } - - remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) - if err != nil { - fp, err := rc.Fingerprint() - if err != nil { - fp = "" - } - - attrs := []slog.Attr{ - slog.Any("error", err), - slog.Any("from", via), - slog.Any("vpnAddrs", hostinfo.vpnAddrs), - slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}), - slog.String("certFingerprint", fp), - slog.Any("certVpnNetworks", rc.Networks()), - } - if f.l.Enabled(context.Background(), slog.LevelDebug) { - attrs = append(attrs, slog.Any("cert", rc)) - } - - // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that - // callers grow conditionally, which has no pair-form equivalent. - //nolint:sloglint - f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) 
- return true - } - if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.Info("public key mismatch between certificate and handshake", - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "cert", remoteCert, - ) - return true - } - - if len(remoteCert.Certificate.Networks()) == 0 { - f.l.Info("No networks in certificate", - "error", err, - "from", via, - "vpnAddrs", hostinfo.vpnAddrs, - "cert", remoteCert, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - return true - } - - vpnNetworks := remoteCert.Certificate.Networks() - certName := remoteCert.Certificate.Name() - certVersion := remoteCert.Certificate.Version() - fingerprint := remoteCert.Fingerprint - issuer := remoteCert.Certificate.Issuer() - - hostinfo.remoteIndexId = hs.Details.ResponderIndex - hostinfo.lastHandshakeTime = hs.Details.Time - - // Store their cert and our symmetric keys - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - - // Make sure the current udpAddr being used is set for responding - if !via.IsRelayed { - hostinfo.SetRemote(via.UdpAddr) - } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - } - - correctHostResponded := false - anyVpnAddrsInCommon := false - vpnAddrs := make([]netip.Addr, len(vpnNetworks)) - for i, network := range vpnNetworks { - vpnAddrs[i] = network.Addr() - if f.myVpnNetworksTable.Contains(network.Addr()) { - anyVpnAddrsInCommon = true - } - if hostinfo.vpnAddrs[0] == network.Addr() { - // todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not? 
- correctHostResponded = true - } - } - - // Ensure the right host responded - if !correctHostResponded { - f.l.Info("Incorrect host responded to handshake", - "intendedVpnAddrs", hostinfo.vpnAddrs, - "haveVpnNetworks", vpnNetworks, - "from", via, - "certName", certName, - "certVersion", certVersion, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - - // Release our old handshake from pending, it should not continue - f.handshakeManager.DeleteHostInfo(hostinfo) - - // Create a new hostinfo/handshake for the intended vpn ip - //TODO is hostinfo.vpnAddrs[0] always the address to use? - f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { - // Block the current used address - newHH.hostinfo.remotes = hostinfo.remotes - newHH.hostinfo.remotes.BlockRemote(via) - - f.l.Info("Blocked addresses for handshakes", - "blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(), - "vpnNetworks", vpnNetworks, - "remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()), - ) - - // Swap the packet store to benefit the original intended recipient - newHH.packetStore = hh.packetStore - hh.packetStore = []*cachedPacket{} - - // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnAddrs = vpnAddrs - f.sendCloseTunnel(hostinfo) - }) - - return true - } - - // Mark packet 2 as seen so it doesn't show up as missed - ci.window.Update(f.l, 2) - - duration := time.Since(hh.startTime).Nanoseconds() - msgRxL := f.l.With( - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "durationNs", duration, - "sentCachedPackets", len(hh.packetStore), - ) - if anyVpnAddrsInCommon { - msgRxL.Info("Handshake 
message received") - } else { - //todo warn if not lighthouse or relay? - msgRxL.Info("Handshake message received, but no vpnNetworks in common.") - } - - // Build up the radix for the firewall if we have subnets in the cert - hostinfo.vpnAddrs = vpnAddrs - hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) - - // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here - f.handshakeManager.Complete(hostinfo, f) - f.connectionManager.AddTrafficWatch(hostinfo) - - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("Sending stored packets", - "count", len(hh.packetStore), - ) - } - - if len(hh.packetStore) > 0 { - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for _, cp := range hh.packetStore { - cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) - } - f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) - } - - hostinfo.remotes.RefreshFromHandshake(vpnAddrs) - f.metricHandshakes.Update(duration) - - // Don't wait for UpdateWorker - if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { - f.lightHouse.TriggerUpdate() - } - - return false -} diff --git a/handshake_manager.go b/handshake_manager.go index 8040ec2e..9fc69ff4 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -14,6 +14,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/handshake" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" ) @@ -23,6 +24,18 @@ const ( DefaultHandshakeRetries = 10 DefaultHandshakeTriggerBuffer = 64 DefaultUseRelays = true + + // maxCachedPackets is how many unsent packets we'll buffer per pending + // handshake before dropping further ones. 
+ maxCachedPackets = 100 + + // HandshakePacket map keys mirror the IX protocol stage convention: + // stage 0 = the initiator's first message (and what the responder + // receives, stripped of header) + // stage 2 = the responder's reply + // Other handshake patterns will need new keys when added. + handshakePacketStage0 uint8 = 0 + handshakePacketStage2 uint8 = 2 ) var ( @@ -76,10 +89,11 @@ type HandshakeHostInfo struct { packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo + machine *handshake.Machine // The handshake state machine, set during stage 0 (initiator) or beginHandshake (responder multi-message) } func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { - if len(hh.packetStore) < 100 { + if len(hh.packetStore) < maxCachedPackets { tempPacket := make([]byte, len(packet)) copy(tempPacket, packet) @@ -137,6 +151,18 @@ func (hm *HandshakeManager) Run(ctx context.Context) { } func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) { + // Gate on known handshake subtypes. Unknown subtypes (or future ones we + // don't yet support) are dropped here rather than silently routed through + // the IX path. Add a case when introducing a new pattern. 
+ switch h.Subtype { + case header.HandshakeIXPSK0: + // supported + default: + hm.l.Debug("dropping handshake with unsupported subtype", + "from", via, "subtype", h.Subtype) + return + } + // First remote allow list check before we know the vpnIp if !via.IsRelayed { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { @@ -145,19 +171,27 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head } } - switch h.Subtype { - case header.HandshakeIXPSK0: - switch h.MessageCounter { - case 1: - ixHandshakeStage1(hm.f, via, packet, h) - - case 2: - newHostinfo := hm.queryIndex(h.RemoteIndex) - tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) - if tearDown && newHostinfo != nil { - hm.DeleteHostInfo(newHostinfo.hostinfo) - } + // First message of a new handshake. The wire format requires RemoteIndex + // to be zero here (the initiator has no responder index to fill in yet), + // and generateIndex never allocates 0, so any non-zero RemoteIndex on a + // stage-1 packet is malformed or someone probing for an index collision. + // Drop without paying the cost of running noise on a pending Machine. + if h.MessageCounter == 1 { + if h.RemoteIndex != 0 { + hm.l.Debug("dropping stage-1 handshake with non-zero RemoteIndex", + "from", via, "remoteIndex", h.RemoteIndex) + return } + hm.beginHandshake(via, packet, h) + return + } + + // Continuation message must match a pending handshake by index. + // Anything else is an orphaned packet (e.g., late retransmit after + // timeout) and is dropped. 
+ if hh := hm.queryIndex(h.RemoteIndex); hh != nil { + hm.continueHandshake(via, hh, packet) + return } } @@ -183,13 +217,22 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo := hh.hostinfo // If we are out of time, clean up if hh.counter >= hm.config.retries { - hh.hostinfo.logger(hm.l).Info("Handshake timed out", + fields := []any{ "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), "initiatorIndex", hh.hostinfo.localIndexId, "remoteIndex", hh.hostinfo.remoteIndexId, - "handshake", m{"stage": 1, "style": "ix_psk0"}, "durationNs", time.Since(hh.startTime).Nanoseconds(), - ) + } + // hh.machine can be nil here if buildStage0Packet never succeeded + // (e.g., no certificate available). In that case there's no useful + // handshake metadata to log. + if hh.machine != nil { + fields = append(fields, "handshake", m{ + "stage": uint64(hh.machine.MessageIndex()), + "style": header.SubTypeName(header.Handshake, hh.machine.Subtype()), + }) + } + hh.hostinfo.logger(hm.l).Info("Handshake timed out", fields...) hm.metricTimedOut.Inc(1) hm.DeleteHostInfo(hostinfo) return @@ -200,12 +243,25 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Check if we have a handshake packet to transmit yet if !hh.ready { - if !ixHandshakeStage0(hm.f, hh) { + if !hm.buildStage0Packet(hh) { hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) return } } + // TODO: this hardcodes "always retransmit stage 0", which is correct for + // IX (the initiator only ever sends one packet, msg1) but wrong the + // moment a 3+ message pattern lands. The retry loop should resend the + // most recent outgoing message, not always stage 0. That implies + // HandshakeHostInfo tracking a single "currentOutbound" packet (bytes + + // header metadata) that gets replaced as the handshake progresses, + // instead of indexing into HandshakePacket. 
+ stage0 := hostinfo.HandshakePacket[handshakePacketStage0] + hsFields := m{ + "stage": uint64(hh.machine.MessageIndex()), + "style": header.SubTypeName(header.Handshake, hh.machine.Subtype()), + } + // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case // NB ^ This comment doesn't jive. It's how the thing gets initialized. @@ -239,13 +295,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []netip.AddrPort hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { - hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + hm.messageMetrics.Tx(header.Handshake, hh.machine.Subtype(), 1) + err := hm.outside.WriteTo(stage0, addr) if err != nil { hostinfo.logger(hm.l).Error("Failed to send handshake message", "udpAddr", addr, "initiatorIndex", hostinfo.localIndexId, - "handshake", m{"stage": 1, "style": "ix_psk0"}, + "handshake", hsFields, "error", err, ) @@ -260,13 +316,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).Info("Handshake message sent", "udpAddrs", sentTo, "initiatorIndex", hostinfo.localIndexId, - "handshake", m{"stage": 1, "style": "ix_psk0"}, + "handshake", hsFields, ) } else if hm.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(hm.l).Debug("Handshake message sent", "udpAddrs", sentTo, "initiatorIndex", hostinfo.localIndexId, - "handshake", m{"stage": 1, "style": "ix_psk0"}, + "handshake", hsFields, ) } @@ -348,7 +404,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered switch existingRelay.State { case Established: hostinfo.logger(hm.l).Info("Send handshake via relay", 
"relay", relay.String()) - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[handshakePacketStage0], make([]byte, 12), make([]byte, mtu), false) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) @@ -587,7 +643,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId -func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { +func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) (uint32, error) { hm.mainHostMap.RLock() defer hm.mainHostMap.RUnlock() hm.Lock() @@ -596,7 +652,7 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { for range 32 { index, err := generateIndex(hm.l) if err != nil { - return err + return 0, err } _, inPending := hm.indexes[index] @@ -605,11 +661,11 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { if !inMain && !inPending { hh.hostinfo.localIndexId = index hm.indexes[index] = hh - return nil + return index, nil } } - return errors.New("failed to generate unique localIndexId") + return 0, errors.New("failed to generate unique localIndexId") } func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { @@ -728,3 +784,524 @@ func generateIndex(l *slog.Logger) (uint32, error) { func hsTimeout(tries int64, interval time.Duration) time.Duration { return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval))) } + +// buildStage0Packet creates the initial handshake packet for the initiator. 
+func (hm *HandshakeManager) buildStage0Packet(hh *HandshakeHostInfo) bool { + cs := hm.f.pki.getCertState() + v := cs.DefaultVersion() + if hh.initiatingVersionOverride != cert.VersionPre1 { + v = hh.initiatingVersionOverride + } else if v < cert.Version2 { + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } + } + } + + cred := cs.GetCredential(v) + if cred == nil { + hm.f.l.Error("Unable to handshake with host because no certificate is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, "certVersion", v) + return false + } + + machine, err := handshake.NewMachine( + v, cs.GetCredential, + hm.certVerifier(), func() (uint32, error) { return hm.allocateIndex(hh) }, + true, header.HandshakeIXPSK0, + ) + if err != nil { + hm.f.l.Error("Failed to create handshake machine", + "vpnAddrs", hh.hostinfo.vpnAddrs, "error", err) + return false + } + + msg, err := machine.Initiate(nil) + if err != nil { + hm.f.l.Error("Failed to initiate handshake", + "vpnAddrs", hh.hostinfo.vpnAddrs, "error", err) + return false + } + + // hostinfo.ConnectionState stays nil until the handshake completes in + // continueHandshake. Pre-completion control surfaces guard with nil + // checks; the data plane never observes a pending hostinfo. + hh.hostinfo.HandshakePacket[handshakePacketStage0] = msg + hh.machine = machine + hh.ready = true + return true +} + +// beginHandshake handles an incoming handshake packet that doesn't match any +// existing pending handshake. It creates a new responder Machine and processes +// the first message. 
+func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *header.H) { + f := hm.f + cs := f.pki.getCertState() + + v := cs.DefaultVersion() + if cs.GetCredential(v) == nil { + f.l.Error("Unable to handshake with host because no certificate is available", + "from", via, "certVersion", v) + return + } + + machine, err := handshake.NewMachine( + v, cs.GetCredential, + hm.certVerifier(), func() (uint32, error) { return generateIndex(f.l) }, + false, header.HandshakeIXPSK0, + ) + if err != nil { + f.l.Error("Failed to create handshake machine", "from", via, "error", err) + return + } + + response, result, err := machine.ProcessPacket(nil, packet) + if err != nil { + f.l.Error("Failed to process handshake packet", "from", via, "error", err) + return + } + + if result == nil { + // Multi-message pattern: the responder Machine would need to be + // registered in hm.indexes so a future inbound packet finds it via + // continueHandshake. The current manager doesn't do that yet, so + // fail loudly rather than silently dropping the in-flight handshake. + // TODO: support multi-message responder flows (XX, pqIX, etc.). + // See also the IX-shaped cipher key assignment in handshake.Machine. 
+ f.l.Error("multi-message handshake responder is not supported", + "from", via, "error", handshake.ErrMultiMessageUnsupported) + return + } + + remoteCert := result.RemoteCert + if remoteCert == nil { + f.l.Error("Handshake did not produce a peer certificate", "from", via) + return + } + + // Validate peer identity + vpnAddrs, anyVpnAddrsInCommon, ok := hm.validatePeerCert(via, remoteCert) + if !ok { + return + } + + hostinfo := &HostInfo{ + ConnectionState: newConnectionStateFromResult(result), + localIndexId: result.LocalIndex, + remoteIndexId: result.RemoteIndex, + vpnAddrs: vpnAddrs, + HandshakePacket: make(map[uint8][]byte, 0), + lastHandshakeTime: result.HandshakeTime, + relayState: RelayState{ + relays: nil, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + } + + msg := "Handshake message received" + if !anyVpnAddrsInCommon { + msg = "Handshake message received, but no vpnNetworks in common." + } + f.l.Info(msg, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "initiatorIndex", result.RemoteIndex, + "responderIndex", result.LocalIndex, + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + + // packet aliases the listener's incoming buffer, so this copy must stay. + hostinfo.HandshakePacket[handshakePacketStage0] = make([]byte, len(packet[header.Len:])) + copy(hostinfo.HandshakePacket[handshakePacketStage0], packet[header.Len:]) + + // response was freshly allocated by ProcessPacket; safe to retain directly. 
+ if response != nil { + hostinfo.HandshakePacket[handshakePacketStage2] = response + } + + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) + } + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) + + existing, err := hm.CheckAndComplete(hostinfo, handshakePacketStage0, f) + if err != nil { + hm.handleCheckAndCompleteError(err, existing, hostinfo, via) + return + } + + hm.sendHandshakeResponse(via, response, hostinfo, false) + f.connectionManager.AddTrafficWatch(hostinfo) + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) + + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } +} + +// continueHandshake feeds an incoming packet to an existing pending handshake Machine. +func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostInfo, packet []byte) { + f := hm.f + + hh.Lock() + defer hh.Unlock() + + // Re-verify hh is still tracked. Between queryIndex returning and us taking + // hh.Lock, handleOutbound may have timed out and deleted it. Once we hold + // hh.Lock no other deleter can race our index: handleOutbound also takes + // hh.Lock first, and handleRecvError targets a main-hostmap entry with a + // different localIndexId. 
+ hm.RLock() + cur, ok := hm.indexes[hh.hostinfo.localIndexId] + hm.RUnlock() + if !ok || cur != hh { + return + } + + hostinfo := hh.hostinfo + if !via.IsRelayed { + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + return + } + } + + machine := hh.machine + if machine == nil { + f.l.Error("No handshake machine available for continuation", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + hm.DeleteHostInfo(hostinfo) + return + } + + response, result, err := machine.ProcessPacket(nil, packet) + if err != nil { + // Recoverable errors are routine noise, log at Debug. Fatal errors get a Warn. + if machine.Failed() { + f.l.Warn("Failed to process handshake packet, abandoning", + "vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err) + hm.DeleteHostInfo(hostinfo) + } else { + f.l.Debug("Failed to process handshake packet", + "vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err) + } + return + } + + if response != nil { + hm.sendHandshakeResponse(via, response, hostinfo, false) + } + + if result == nil { + return + } + + // Handshake complete; build the ConnectionState now that we have keys and a verified peer cert. 
+ hostinfo.ConnectionState = newConnectionStateFromResult(result) + + remoteCert := result.RemoteCert + if remoteCert == nil { + f.l.Error("Handshake completed without peer certificate", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + hm.DeleteHostInfo(hostinfo) + return + } + + vpnNetworks := remoteCert.Certificate.Networks() + hostinfo.remoteIndexId = result.RemoteIndex + hostinfo.lastHandshakeTime = result.HandshakeTime + + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) + } else { + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + } + + // Verify correct host responded (initiator check) + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + correctHostResponded := false + anyVpnAddrsInCommon := false + for i, network := range vpnNetworks { + // inside.go drops self-routed packets at the firewall stage, but we'd + // rather not let a self-handshake complete in the first place: it + // wastes a hostmap slot, suppresses no log, and obscures routing + // misconfig. Explicit refusal here mirrors the responder-side check + // in validatePeerCert. 
+ if f.myVpnAddrsTable.Contains(network.Addr()) { + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + hm.DeleteHostInfo(hostinfo) + return + } + vpnAddrs[i] = network.Addr() + if hostinfo.vpnAddrs[0] == network.Addr() { + correctHostResponded = true + } + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true + } + } + + if !correctHostResponded { + f.l.Info("Incorrect host responded to handshake", + "intendedVpnAddrs", hostinfo.vpnAddrs, + "haveVpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + + hm.DeleteHostInfo(hostinfo) + hm.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { + newHH.hostinfo.remotes = hostinfo.remotes + newHH.hostinfo.remotes.BlockRemote(via) + newHH.packetStore = hh.packetStore + hh.packetStore = []*cachedPacket{} + hostinfo.vpnAddrs = vpnAddrs + f.sendCloseTunnel(hostinfo) + }) + return + } + + duration := time.Since(hh.startTime).Nanoseconds() + msg := "Handshake message received" + if !anyVpnAddrsInCommon { + msg = "Handshake message received, but no vpnNetworks in common." 
+ } + f.l.Info(msg, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "initiatorIndex", result.LocalIndex, + "responderIndex", result.RemoteIndex, + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + "durationNs", duration, + "sentCachedPackets", len(hh.packetStore), + ) + + hostinfo.vpnAddrs = vpnAddrs + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) + + hm.Complete(hostinfo, f) + f.connectionManager.AddTrafficWatch(hostinfo) + + if len(hh.packetStore) > 0 { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Sending stored packets", "count", len(hh.packetStore)) + } + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for _, cp := range hh.packetStore { + cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) + } + f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) + } + + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) + f.metricHandshakes.Update(duration) + + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } +} + +// validatePeerCert checks the peer certificate for self-connection and remote allow list. +// Returns the VPN addrs, whether any of them fall within one of our own VPN +// networks, and true if valid; false if rejected. +func (hm *HandshakeManager) validatePeerCert(via ViaSender, remoteCert *cert.CachedCertificate) ([]netip.Addr, bool, bool) { + f := hm.f + vpnNetworks := remoteCert.Certificate.Networks() + + // The cert package rejects host certs with no networks at parse time, so + // reaching this state would mean an invariant was bypassed elsewhere. 
+ // Refuse explicitly so downstream code (which indexes vpnAddrs[0]) can't + // panic if that invariant ever changes. + if len(vpnNetworks) == 0 { + f.l.Info("No networks in certificate", + "from", via, "cert", remoteCert) + return nil, false, false + } + + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + anyVpnAddrsInCommon := false + + for i, network := range vpnNetworks { + if f.myVpnAddrsTable.Contains(network.Addr()) { + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + ) + return nil, false, false + } + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true + } + } + + if !via.IsRelayed { + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", vpnAddrs, "from", via) + return nil, false, false + } + } + + return vpnAddrs, anyVpnAddrsInCommon, true +} + +// sendHandshakeResponse sends a handshake response via the appropriate transport. +// cached is true when msg is a stored response being retransmitted because +// the peer's stage-1 retransmit landed (the ErrAlreadySeen path); false on a +// fresh response. +func (hm *HandshakeManager) sendHandshakeResponse(via ViaSender, msg []byte, hostinfo *HostInfo, cached bool) { + if msg == nil { + return + } + + f := hm.f + f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) + + // Common log fields. peerCert may be nil during intermediate + // multi-message flows (handshake hasn't completed yet); skip the cert + // block if so. 
+ logFields := []any{ + "vpnAddrs", hostinfo.vpnAddrs, + "handshake", m{"stage": uint64(2), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)}, + "cached", cached, + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + } + if peerCert := hostinfo.ConnectionState.peerCert; peerCert != nil { + logFields = append(logFields, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + ) + } + + if !via.IsRelayed { + fields := append(logFields, "from", via) + err := f.outside.WriteTo(msg, via.UdpAddr) + if err != nil { + f.l.Error("Failed to send handshake message", append(fields, "error", err)...) + } else { + f.l.Info("Handshake message sent", fields...) + } + } else { + if via.relay == nil { + f.l.Error("Handshake send failed: both addr and via.relay are nil.") + return + } + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + // We received a valid handshake on this relay, so make sure the relay + // state reflects that, in case it had been marked Disestablished. + via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.Info("Handshake message sent", append(logFields, "relay", via.relayHI.vpnAddrs[0])...) + } +} + +// handleCheckAndCompleteError handles errors from CheckAndComplete. +// This only fires from the responder-side beginHandshake path, after the +// peer cert has been validated and ConnectionState populated, so peerCert +// is always non-nil for the cases that log it. 
+func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hostinfo *HostInfo, via ViaSender) { + f := hm.f + peerCert := hostinfo.ConnectionState.peerCert + hsFields := m{"stage": uint64(1), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)} + + switch err { + case ErrAlreadySeen: + if existing.SetRemoteIfPreferred(f.hostMap, via) { + f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + } + // Resend the original response. The peer is committed to that response's + // ephemeral keys; a freshly-built one would have different keys and break + // the tunnel even though both sides "completed" the handshake. + if msg := existing.HandshakePacket[handshakePacketStage2]; msg != nil { + hm.sendHandshakeResponse(via, msg, existing, true) + } + + case ErrExistingHostInfo: + f.l.Info("Handshake too old", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "oldHandshakeTime", existing.lastHandshakeTime, + "newHandshakeTime", hostinfo.lastHandshakeTime, + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + + case ErrLocalIndexCollision: + f.l.Error("Failed to add HostInfo due to localIndex collision", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "localIndex", hostinfo.localIndexId, + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + + default: + 
f.l.Error("Failed to add HostInfo to HostMap", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "error", err, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + } +} + +// certVerifier returns a CertVerifier that validates certs against the current CA pool. +func (hm *HandshakeManager) certVerifier() handshake.CertVerifier { + return func(c cert.Certificate) (*cert.CachedCertificate, error) { + return hm.f.pki.GetCAPool().VerifyCertificate(time.Now(), c) + } +} diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 2e6d34b5..5f8383e4 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" @@ -27,7 +28,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -100,3 +101,137 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { func (mw *mockEncWriter) GetCertState() *CertState { return &CertState{initiatingVersion: cert.Version2} } + +func TestValidatePeerCert(t *testing.T) { + l := test.NewLogger() + + myNetwork := netip.MustParsePrefix("10.0.0.1/24") + myAddrTable := new(bart.Lite) + myAddrTable.Insert(netip.PrefixFrom(myNetwork.Addr(), myNetwork.Addr().BitLen())) + myNetTable := new(bart.Lite) + myNetTable.Insert(myNetwork.Masked()) + + newHM := func() *HandshakeManager { + hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig) 
+ hm.f = &Interface{ + handshakeManager: hm, + pki: &PKI{}, + l: l, + myVpnAddrsTable: myAddrTable, + myVpnNetworksTable: myNetTable, + lightHouse: hm.lightHouse, + } + return hm + } + + cached := func(networks ...netip.Prefix) *cert.CachedCertificate { + return &cert.CachedCertificate{ + Certificate: &dummyCert{name: "peer", networks: networks}, + } + } + + via := ViaSender{ + UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"), + IsRelayed: true, // skip the remote allow list (covered separately) + } + + t.Run("addr inside our networks sets anyVpnAddrsInCommon", func(t *testing.T) { + hm := newHM() + // 10.0.0.2 falls inside our 10.0.0.0/24 + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.2/24"))) + assert.True(t, ok) + assert.True(t, common) + assert.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.2")}, addrs) + }) + + t.Run("addr outside our networks leaves anyVpnAddrsInCommon false", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("192.168.1.5/24"))) + assert.True(t, ok) + assert.False(t, common) + assert.Equal(t, []netip.Addr{netip.MustParseAddr("192.168.1.5")}, addrs) + }) + + t.Run("any matching network is enough", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached( + netip.MustParsePrefix("192.168.1.5/24"), + netip.MustParsePrefix("10.0.0.42/24"), + )) + assert.True(t, ok) + assert.True(t, common) + assert.Len(t, addrs, 2) + }) + + t.Run("self-handshake is rejected", func(t *testing.T) { + hm := newHM() + // 10.0.0.1 is in myVpnAddrsTable + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.1/24"))) + assert.False(t, ok) + assert.False(t, common) + assert.Nil(t, addrs) + }) + + t.Run("cert with no networks is rejected", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached()) + assert.False(t, ok) + assert.False(t, common) + assert.Nil(t, 
addrs) + }) +} + +func TestHandleIncomingDispatch(t *testing.T) { + l := test.NewLogger() + + newHM := func() *HandshakeManager { + hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig) + hm.f = &Interface{ + handshakeManager: hm, + pki: &PKI{}, + l: l, + } + return hm + } + + via := ViaSender{ + UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"), + IsRelayed: true, // bypass remote allow list + } + + // A packet body of zero length is fine for these tests: dispatch is + // gated on header fields, and we assert that we never reach noise/cert + // processing for any of the malformed shapes here. + pkt := make([]byte, header.Len) + + t.Run("unsupported subtype dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{Type: header.Handshake, Subtype: header.MessageSubType(99), MessageCounter: 1} + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "no pending handshake should be created") + }) + + t.Run("stage-1 with non-zero RemoteIndex dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{ + Type: header.Handshake, + Subtype: header.HandshakeIXPSK0, + RemoteIndex: 0xdeadbeef, + MessageCounter: 1, + } + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "spoofed stage-1 must not create a pending machine") + }) + + t.Run("continuation with no matching pending index dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{ + Type: header.Handshake, + Subtype: header.HandshakeIXPSK0, + RemoteIndex: 0xcafef00d, + MessageCounter: 2, + } + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "orphan stage-2 must not create state") + }) +} diff --git a/nebula.pb.go b/nebula.pb.go index 2fd2ff66..94a4ebe2 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8, 0} + return 
fileDescriptor_2d65afa7693df5ef, []int{6, 0} } type NebulaMeta struct { @@ -489,142 +489,6 @@ func (m *NebulaPing) GetTime() uint64 { return 0 } -type NebulaHandshake struct { - Details *NebulaHandshakeDetails `protobuf:"bytes,1,opt,name=Details,proto3" json:"Details,omitempty"` - Hmac []byte `protobuf:"bytes,2,opt,name=Hmac,proto3" json:"Hmac,omitempty"` -} - -func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } -func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } -func (*NebulaHandshake) ProtoMessage() {} -func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} -} -func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *NebulaHandshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_NebulaHandshake.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *NebulaHandshake) XXX_Merge(src proto.Message) { - xxx_messageInfo_NebulaHandshake.Merge(m, src) -} -func (m *NebulaHandshake) XXX_Size() int { - return m.Size() -} -func (m *NebulaHandshake) XXX_DiscardUnknown() { - xxx_messageInfo_NebulaHandshake.DiscardUnknown(m) -} - -var xxx_messageInfo_NebulaHandshake proto.InternalMessageInfo - -func (m *NebulaHandshake) GetDetails() *NebulaHandshakeDetails { - if m != nil { - return m.Details - } - return nil -} - -func (m *NebulaHandshake) GetHmac() []byte { - if m != nil { - return m.Hmac - } - return nil -} - -type NebulaHandshakeDetails struct { - Cert []byte `protobuf:"bytes,1,opt,name=Cert,proto3" json:"Cert,omitempty"` - InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,proto3" json:"InitiatorIndex,omitempty"` - ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` - Cookie uint64 
`protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` - Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` - CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` -} - -func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } -func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } -func (*NebulaHandshakeDetails) ProtoMessage() {} -func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} -} -func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *NebulaHandshakeDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_NebulaHandshakeDetails.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *NebulaHandshakeDetails) XXX_Merge(src proto.Message) { - xxx_messageInfo_NebulaHandshakeDetails.Merge(m, src) -} -func (m *NebulaHandshakeDetails) XXX_Size() int { - return m.Size() -} -func (m *NebulaHandshakeDetails) XXX_DiscardUnknown() { - xxx_messageInfo_NebulaHandshakeDetails.DiscardUnknown(m) -} - -var xxx_messageInfo_NebulaHandshakeDetails proto.InternalMessageInfo - -func (m *NebulaHandshakeDetails) GetCert() []byte { - if m != nil { - return m.Cert - } - return nil -} - -func (m *NebulaHandshakeDetails) GetInitiatorIndex() uint32 { - if m != nil { - return m.InitiatorIndex - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetResponderIndex() uint32 { - if m != nil { - return m.ResponderIndex - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetCookie() uint64 { - if m != nil { - return m.Cookie - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetTime() uint64 { - if m != nil { - return m.Time - } - return 0 -} - -func (m *NebulaHandshakeDetails) 
GetCertVersion() uint32 { - if m != nil { - return m.CertVersion - } - return 0 -} - type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` @@ -639,7 +503,7 @@ func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8} + return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -729,65 +593,55 @@ func init() { proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") - proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") - proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") proto.RegisterType((*NebulaControl)(nil), "nebula.NebulaControl") } func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 785 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, - 0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, - 0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, - 0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, - 0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, - 0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 
0xf1, 0xed, - 0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, - 0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, - 0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, - 0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, - 0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, - 0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, - 0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, - 0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, - 0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, - 0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, - 0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, - 0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, - 0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, - 0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, - 0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, - 0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, - 0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, - 0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, - 0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, - 0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, - 0x79, 0x1b, 0x88, 0xb1, 
0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8, - 0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, - 0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, - 0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, - 0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, - 0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, - 0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, - 0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, - 0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, - 0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, - 0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, - 0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, - 0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, - 0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, - 0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, - 0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, - 0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf, - 0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f, - 0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8, - 0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65, - 0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 
0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9, - 0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94, - 0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00, - 0x00, + // 665 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x54, 0xcd, 0x6e, 0xd3, 0x5c, + 0x10, 0x8d, 0x1d, 0x27, 0x69, 0x27, 0x4d, 0x3e, 0x7f, 0x53, 0x51, 0x12, 0x24, 0xac, 0xe0, 0x45, + 0x55, 0xb1, 0x48, 0x51, 0x5a, 0xba, 0xa6, 0x2d, 0x42, 0xa9, 0xd4, 0x9f, 0x70, 0x55, 0x8a, 0xc4, + 0xce, 0xb5, 0x2f, 0x8d, 0x55, 0xc7, 0x37, 0xb5, 0x6f, 0x50, 0xf3, 0x16, 0x3c, 0x0c, 0x0f, 0x01, + 0xbb, 0x2e, 0x59, 0xa2, 0x66, 0xc9, 0x92, 0x17, 0x40, 0xf7, 0xfa, 0xbf, 0x31, 0xb0, 0xbb, 0x33, + 0xe7, 0x9c, 0x99, 0xc9, 0xc9, 0x8c, 0x61, 0xcd, 0xa7, 0x97, 0x33, 0xcf, 0xea, 0x4f, 0x03, 0xc6, + 0x19, 0xd6, 0xa3, 0xc8, 0xfc, 0xa9, 0x02, 0x9c, 0xca, 0xe7, 0x09, 0xe5, 0x16, 0x0e, 0x40, 0x3b, + 0x9f, 0x4f, 0x69, 0x47, 0xe9, 0x29, 0x5b, 0xed, 0x81, 0xd1, 0x8f, 0x35, 0x19, 0xa3, 0x7f, 0x42, + 0xc3, 0xd0, 0xba, 0xa2, 0x82, 0x45, 0x24, 0x17, 0x77, 0xa0, 0xf1, 0x9a, 0x72, 0xcb, 0xf5, 0xc2, + 0x8e, 0xda, 0x53, 0xb6, 0x9a, 0x83, 0xee, 0xb2, 0x2c, 0x26, 0x90, 0x84, 0x69, 0xfe, 0x52, 0xa0, + 0x99, 0x2b, 0x85, 0x2b, 0xa0, 0x9d, 0x32, 0x9f, 0xea, 0x15, 0x6c, 0xc1, 0xea, 0x90, 0x85, 0xfc, + 0xed, 0x8c, 0x06, 0x73, 0x5d, 0x41, 0x84, 0x76, 0x1a, 0x12, 0x3a, 0xf5, 0xe6, 0xba, 0x8a, 0x4f, + 0x60, 0x43, 0xe4, 0xde, 0x4d, 0x1d, 0x8b, 0xd3, 0x53, 0xc6, 0xdd, 0x8f, 0xae, 0x6d, 0x71, 0x97, + 0xf9, 0x7a, 0x15, 0xbb, 0xf0, 0x48, 0x60, 0x27, 0xec, 0x13, 0x75, 0x0a, 0x90, 0x96, 0x40, 0xa3, + 0x99, 0x6f, 0x8f, 0x0b, 0x50, 0x0d, 0xdb, 0x00, 0x02, 0x7a, 0x3f, 0x66, 0xd6, 0xc4, 0xd5, 0xeb, + 0xb8, 0x0e, 0xff, 0x65, 0x71, 0xd4, 0xb6, 0x21, 0x26, 0x1b, 0x59, 0x7c, 0x7c, 0x38, 0xa6, 0xf6, + 0xb5, 0xbe, 0x22, 0x26, 0x4b, 0xc3, 0x88, 0xb2, 0x8a, 0x4f, 0xa1, 0x5b, 0x3e, 0xd9, 0xbe, 0x7d, + 0xad, 0x83, 0xf9, 0x4d, 0x85, 0xff, 0x97, 
0x4c, 0x41, 0x13, 0xe0, 0xcc, 0x73, 0x2e, 0xa6, 0xfe, + 0xbe, 0xe3, 0x04, 0xd2, 0xfa, 0xd6, 0x81, 0xda, 0x51, 0x48, 0x2e, 0x8b, 0x9b, 0xd0, 0x48, 0x08, + 0x75, 0x69, 0xf2, 0x5a, 0x62, 0xb2, 0xc8, 0x91, 0x04, 0xc4, 0x3e, 0xe8, 0x67, 0x9e, 0x43, 0xa8, + 0x67, 0xcd, 0xe3, 0x54, 0xd8, 0xa9, 0xf5, 0xaa, 0x71, 0xc5, 0x25, 0x0c, 0x07, 0xd0, 0x2a, 0x92, + 0x1b, 0xbd, 0xea, 0x52, 0xf5, 0x22, 0x05, 0x77, 0xa1, 0x79, 0xb1, 0x2b, 0x9e, 0x23, 0x16, 0x70, + 0xf1, 0xa7, 0x0b, 0x05, 0x26, 0x8a, 0x0c, 0x22, 0x79, 0x9a, 0x54, 0xed, 0x65, 0x2a, 0xed, 0x81, + 0x6a, 0x2f, 0xa7, 0xca, 0x68, 0xd8, 0x81, 0x86, 0xcd, 0x66, 0x3e, 0xa7, 0x41, 0xa7, 0x2a, 0x8c, + 0x21, 0x49, 0x68, 0x6e, 0x82, 0x26, 0x7f, 0x71, 0x1b, 0xd4, 0xa1, 0x2b, 0x5d, 0xd3, 0x88, 0x3a, + 0x74, 0x45, 0x7c, 0xcc, 0xe4, 0x26, 0x6a, 0x44, 0x3d, 0x66, 0xe6, 0x2e, 0x40, 0x36, 0x06, 0x62, + 0xa4, 0x8a, 0x5c, 0x26, 0x51, 0x05, 0x04, 0x4d, 0x60, 0x52, 0xd3, 0x22, 0xf2, 0x6d, 0xbe, 0x02, + 0xc8, 0xc6, 0xf8, 0x57, 0x8f, 0xb4, 0x42, 0x35, 0x57, 0xe1, 0x36, 0x39, 0xac, 0x91, 0xeb, 0x5f, + 0xfd, 0xfd, 0xb0, 0x04, 0xa3, 0xe4, 0xb0, 0x10, 0xb4, 0x73, 0x77, 0x42, 0xe3, 0x3e, 0xf2, 0x6d, + 0x9a, 0x4b, 0x67, 0x23, 0xc4, 0x7a, 0x05, 0x57, 0xa1, 0x16, 0x2d, 0xa1, 0x62, 0x7e, 0xa9, 0x42, + 0x2b, 0x2a, 0x7c, 0xc8, 0x7c, 0x1e, 0x30, 0x0f, 0x5f, 0x16, 0xba, 0x3f, 0x2b, 0x76, 0x8f, 0x49, + 0x25, 0x03, 0xbc, 0x80, 0xf5, 0x23, 0xdf, 0xe5, 0xae, 0xc5, 0x59, 0x20, 0x57, 0xe0, 0xc8, 0x77, + 0xe8, 0x6d, 0xec, 0x53, 0x19, 0x24, 0x14, 0x84, 0x86, 0x53, 0xe6, 0x3b, 0x34, 0xaf, 0x88, 0x7c, + 0x29, 0x83, 0xf0, 0x39, 0xb4, 0x93, 0xa5, 0x3c, 0x67, 0xf2, 0xaf, 0xd1, 0xd2, 0x03, 0x78, 0x80, + 0xe4, 0x97, 0xfb, 0x4d, 0xc0, 0x26, 0x92, 0x5d, 0x4b, 0xd9, 0x4b, 0x18, 0xf6, 0xa1, 0x99, 0x2f, + 0x5c, 0x76, 0x38, 0x79, 0x42, 0x7a, 0x0c, 0x69, 0xf1, 0x46, 0x89, 0xa2, 0x48, 0x31, 0x87, 0x7f, + 0xfa, 0x8e, 0x6d, 0x00, 0x1e, 0x06, 0xd4, 0xe2, 0x54, 0xf2, 0x09, 0xbd, 0x99, 0xd1, 0x90, 0xeb, + 0x0a, 0x3e, 0x86, 0xf5, 0x42, 0x5e, 0x58, 0x12, 0x52, 0x5d, 0x3d, 0xd8, 0xf9, 
0x7a, 0x6f, 0x28, + 0x77, 0xf7, 0x86, 0xf2, 0xe3, 0xde, 0x50, 0x3e, 0x2f, 0x8c, 0xca, 0xdd, 0xc2, 0xa8, 0x7c, 0x5f, + 0x18, 0x95, 0x0f, 0xdd, 0x2b, 0x97, 0x8f, 0x67, 0x97, 0x7d, 0x9b, 0x4d, 0xb6, 0x43, 0xcf, 0xb2, + 0xaf, 0xc7, 0x37, 0xdb, 0xd1, 0x48, 0x97, 0x75, 0xf9, 0x39, 0xdf, 0xf9, 0x1d, 0x00, 0x00, 0xff, + 0xff, 0x51, 0x0a, 0xe3, 0xd7, 0xde, 0x05, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -1072,103 +926,6 @@ func (m *NebulaPing) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } -func (m *NebulaHandshake) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *NebulaHandshake) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *NebulaHandshake) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if len(m.Hmac) > 0 { - i -= len(m.Hmac) - copy(dAtA[i:], m.Hmac) - i = encodeVarintNebula(dAtA, i, uint64(len(m.Hmac))) - i-- - dAtA[i] = 0x12 - } - if m.Details != nil { - { - size, err := m.Details.MarshalToSizedBuffer(dAtA[:i]) - if err != nil { - return 0, err - } - i -= size - i = encodeVarintNebula(dAtA, i, uint64(size)) - } - i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - -func (m *NebulaHandshakeDetails) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *NebulaHandshakeDetails) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if m.CertVersion != 0 { - i = encodeVarintNebula(dAtA, i, 
uint64(m.CertVersion)) - i-- - dAtA[i] = 0x40 - } - if m.Time != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Time)) - i-- - dAtA[i] = 0x28 - } - if m.Cookie != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Cookie)) - i-- - dAtA[i] = 0x20 - } - if m.ResponderIndex != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.ResponderIndex)) - i-- - dAtA[i] = 0x18 - } - if m.InitiatorIndex != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.InitiatorIndex)) - i-- - dAtA[i] = 0x10 - } - if len(m.Cert) > 0 { - i -= len(m.Cert) - copy(dAtA[i:], m.Cert) - i = encodeVarintNebula(dAtA, i, uint64(len(m.Cert))) - i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - func (m *NebulaControl) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) @@ -1375,51 +1132,6 @@ func (m *NebulaPing) Size() (n int) { return n } -func (m *NebulaHandshake) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.Details != nil { - l = m.Details.Size() - n += 1 + l + sovNebula(uint64(l)) - } - l = len(m.Hmac) - if l > 0 { - n += 1 + l + sovNebula(uint64(l)) - } - return n -} - -func (m *NebulaHandshakeDetails) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - l = len(m.Cert) - if l > 0 { - n += 1 + l + sovNebula(uint64(l)) - } - if m.InitiatorIndex != 0 { - n += 1 + sovNebula(uint64(m.InitiatorIndex)) - } - if m.ResponderIndex != 0 { - n += 1 + sovNebula(uint64(m.ResponderIndex)) - } - if m.Cookie != 0 { - n += 1 + sovNebula(uint64(m.Cookie)) - } - if m.Time != 0 { - n += 1 + sovNebula(uint64(m.Time)) - } - if m.CertVersion != 0 { - n += 1 + sovNebula(uint64(m.CertVersion)) - } - return n -} - func (m *NebulaControl) Size() (n int) { if m == nil { return 0 @@ -2236,305 +1948,6 @@ func (m *NebulaPing) Unmarshal(dAtA []byte) error { } return nil } -func (m *NebulaHandshake) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift 
>= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: NebulaHandshake: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: NebulaHandshake: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Details", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + msglen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.Details == nil { - m.Details = &NebulaHandshakeDetails{} - } - if err := m.Details.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Hmac", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Hmac = append(m.Hmac[:0], dAtA[iNdEx:postIndex]...) 
- if m.Hmac == nil { - m.Hmac = []byte{} - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipNebula(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthNebula - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: NebulaHandshakeDetails: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: NebulaHandshakeDetails: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Cert", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Cert = append(m.Cert[:0], dAtA[iNdEx:postIndex]...) 
- if m.Cert == nil { - m.Cert = []byte{} - } - iNdEx = postIndex - case 2: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field InitiatorIndex", wireType) - } - m.InitiatorIndex = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.InitiatorIndex |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 3: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field ResponderIndex", wireType) - } - m.ResponderIndex = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.ResponderIndex |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 4: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Cookie", wireType) - } - m.Cookie = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Cookie |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 5: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) - } - m.Time = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Time |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 8: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType) - } - m.CertVersion = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.CertVersion |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - default: - iNdEx = preIndex - 
skippy, err := skipNebula(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthNebula - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} func (m *NebulaControl) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 diff --git a/nebula.proto b/nebula.proto index ea102334..7b44f473 100644 --- a/nebula.proto +++ b/nebula.proto @@ -60,21 +60,9 @@ message NebulaPing { uint64 Time = 2; } -message NebulaHandshake { - NebulaHandshakeDetails Details = 1; - bytes Hmac = 2; -} - -message NebulaHandshakeDetails { - bytes Cert = 1; - uint32 InitiatorIndex = 2; - uint32 ResponderIndex = 3; - uint64 Cookie = 4; - uint64 Time = 5; - uint32 CertVersion = 8; - // reserved for WIP multiport - reserved 6, 7; -} +// NebulaHandshake / NebulaHandshakeDetails moved to +// handshake/handshake.proto. The handshake package speaks that wire format +// directly via a hand-written encoder/decoder. 
message NebulaControl { enum MessageType { diff --git a/pki.go b/pki.go index fb8cc5c6..acc80486 100644 --- a/pki.go +++ b/pki.go @@ -15,9 +15,12 @@ import ( "sync/atomic" "time" + "github.com/flynn/noise" "github.com/gaissmai/bart" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/util" ) @@ -28,11 +31,11 @@ type PKI struct { } type CertState struct { - v1Cert cert.Certificate - v1HandshakeBytes []byte + v1Cert cert.Certificate + v1Credential *handshake.Credential - v2Cert cert.Certificate - v2HandshakeBytes []byte + v2Cert cert.Certificate + v2Credential *handshake.Credential initiatingVersion cert.Version privateKey []byte @@ -92,13 +95,35 @@ func (p *PKI) reload(c *config.C, initial bool) error { } func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { - newState, err := newCertStateFromConfig(c) + var cipher string + var currentState *CertState + if initial { + cipher = c.GetString("cipher", "aes") + //TODO: this sucks and we should make it not a global + switch cipher { + case "aes": + noiseEndianness = binary.BigEndian + case "chachapoly": + noiseEndianness = binary.LittleEndian + default: + return util.NewContextualError( + "unknown cipher", + m{"cipher": cipher}, + nil, + ) + } + } else { + // Cipher cant be hot swapped so just leave it at what it was before + currentState = p.cs.Load() + cipher = currentState.cipher + } + + newState, err := newCertStateFromConfig(c, cipher) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } - if !initial { - currentState := p.cs.Load() + if currentState != nil { if newState.v1Cert != nil { if currentState.v1Cert == nil { //adding certs is fine, actually. Networks-in-common confirmed in newCertState(). 
@@ -158,25 +183,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { ) } } - - // Cipher cant be hot swapped so just leave it at what it was before - newState.cipher = currentState.cipher - - } else { - newState.cipher = c.GetString("cipher", "aes") - //TODO: this sucks and we should make it not a global - switch newState.cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian - default: - return util.NewContextualError( - "unknown cipher", - m{"cipher": newState.cipher}, - nil, - ) - } } p.cs.Store(newState) @@ -208,6 +214,20 @@ func (cs *CertState) GetDefaultCertificate() cert.Certificate { return c } +// DefaultVersion returns the preferred cert version for initiating handshakes. +func (cs *CertState) DefaultVersion() cert.Version { return cs.initiatingVersion } + +// GetCredential returns the pre-computed handshake credential for the given version, or nil. +func (cs *CertState) GetCredential(v cert.Version) *handshake.Credential { + switch v { + case cert.Version1: + return cs.v1Credential + case cert.Version2: + return cs.v2Credential + } + return nil +} + func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { switch v { case cert.Version1: @@ -219,17 +239,25 @@ func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { return nil } -// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. -// Callers must check if the return []byte is nil. 
-func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { - switch v { - case cert.Version1: - return cs.v1HandshakeBytes - case cert.Version2: - return cs.v2HandshakeBytes +func newCipherSuite(curve cert.Curve, pkcs11backed bool, cipher string) (noise.CipherSuite, error) { + var dhFunc noise.DHFunc + switch curve { + case cert.Curve_CURVE25519: + dhFunc = noise.DH25519 + case cert.Curve_P256: + if pkcs11backed { + dhFunc = noiseutil.DHP256PKCS11 + } else { + dhFunc = noiseutil.DHP256 + } default: - return nil + return nil, fmt.Errorf("unsupported curve: %s", curve) } + + if cipher == "chachapoly" { + return noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256), nil + } + return noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256), nil } func (cs *CertState) String() string { @@ -261,7 +289,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) { return json.Marshal(msg) } -func newCertStateFromConfig(c *config.C) (*CertState, error) { +func newCertStateFromConfig(c *config.C, cipher string) (*CertState, error) { var err error privPathOrPEM := c.GetString("pki.key", "") @@ -345,13 +373,14 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) } - return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) + return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey, cipher) } -func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { +func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte, cipher string) (*CertState, error) { cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, + cipher: cipher, myVpnNetworksTable: new(bart.Lite), myVpnAddrsTable: new(bart.Lite), myVpnBroadcastAddrsTable: new(bart.Lite), @@ -384,10 +413,14 @@ func 
newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p v1hs, err := v1.MarshalForHandshakes() if err != nil { - return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + return nil, fmt.Errorf("error marshalling v1 certificate for handshake: %w", err) + } + ncs, err := newCipherSuite(v1.Curve(), pkcs11backed, cipher) + if err != nil { + return nil, err } cs.v1Cert = v1 - cs.v1HandshakeBytes = v1hs + cs.v1Credential = handshake.NewCredential(v1, v1hs, privateKey, ncs) if cs.initiatingVersion == 0 { cs.initiatingVersion = cert.Version1 @@ -405,10 +438,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p v2hs, err := v2.MarshalForHandshakes() if err != nil { - return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + return nil, fmt.Errorf("error marshalling v2 certificate for handshake: %w", err) + } + ncs, err := newCipherSuite(v2.Curve(), pkcs11backed, cipher) + if err != nil { + return nil, err } cs.v2Cert = v2 - cs.v2HandshakeBytes = v2hs + cs.v2Credential = handshake.NewCredential(v2, v2hs, privateKey, ncs) if cs.initiatingVersion == 0 { cs.initiatingVersion = cert.Version2 From f141cebe8d52fb9c30af93208f36e2a59e7287e5 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 30 Apr 2026 21:30:56 -0500 Subject: [PATCH 40/44] Run e2e tests in parallel, include a goroutine leak detector test (#1700) --- e2e/handshake_manager_test.go | 12 +++++++++ e2e/handshakes_test.go | 21 +++++++++++++++ e2e/leak_test.go | 51 +++++++++++++++++++++++++++++++++++ e2e/tunnels_test.go | 6 +++++ go.mod | 1 + 5 files changed, 91 insertions(+) create mode 100644 e2e/leak_test.go diff --git a/e2e/handshake_manager_test.go b/e2e/handshake_manager_test.go index 3fe784c1..1c6ebacc 100644 --- a/e2e/handshake_manager_test.go +++ b/e2e/handshake_manager_test.go @@ -28,6 +28,7 @@ func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, } func 
TestHandshakeRetransmitDuplicate(t *testing.T) { + t.Parallel() // Verify the responder correctly handles receiving the same msg1 multiple times // (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen // and the cached response is resent. @@ -78,6 +79,7 @@ func TestHandshakeRetransmitDuplicate(t *testing.T) { } func TestHandshakeTruncatedPacketRecovery(t *testing.T) { + t.Parallel() // Verify that a truncated handshake packet is ignored and the real // packet can still complete the handshake. @@ -126,6 +128,7 @@ func TestHandshakeTruncatedPacketRecovery(t *testing.T) { } func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { + t.Parallel() // A msg2 arriving with no matching pending index should be silently dropped // with no response sent and no state changes. @@ -168,6 +171,7 @@ func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { } func TestHandshakeUnknownMessageCounter(t *testing.T) { + t.Parallel() // A handshake packet with an unexpected message counter should be silently // dropped with no side effects and no UDP response. @@ -199,6 +203,7 @@ func TestHandshakeUnknownMessageCounter(t *testing.T) { } func TestHandshakeUnknownSubtype(t *testing.T) { + t.Parallel() // A handshake packet with an unknown subtype should be silently dropped. ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -224,6 +229,7 @@ func TestHandshakeUnknownSubtype(t *testing.T) { } func TestHandshakeLateResponse(t *testing.T) { + t.Parallel() // After a handshake times out, a late response should be silently ignored // with no new tunnels created. @@ -273,6 +279,7 @@ func TestHandshakeLateResponse(t *testing.T) { } func TestHandshakeSelfConnectionRejected(t *testing.T) { + t.Parallel() // Verify that a node rejects a handshake containing its own VPN IP in the // peer cert. We do this by sending the initiator's own msg1 back to itself. 
@@ -321,6 +328,7 @@ func TestHandshakeSelfConnectionRejected(t *testing.T) { } func TestHandshakeMessageCounter0Dropped(t *testing.T) { + t.Parallel() // MessageCounter=0 is not a valid handshake message and should be dropped. ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -341,6 +349,7 @@ func TestHandshakeMessageCounter0Dropped(t *testing.T) { } func TestHandshakeRemoteAllowList(t *testing.T) { + t.Parallel() // Verify that a handshake from a blocked underlay IP is dropped with no // response and no state changes. Then verify the same packet from an // allowed IP succeeds. @@ -399,6 +408,7 @@ func TestHandshakeRemoteAllowList(t *testing.T) { } func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { + t.Parallel() // When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel // remains functional and hostmap index count is stable. @@ -445,6 +455,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { } func TestHandshakeWrongResponderPacketStore(t *testing.T) { + t.Parallel() // Verify that when the wrong host responds, the cached packets are // transferred to the new handshake, the evil tunnel is closed, evil's // address is blocked, and the correct tunnel is eventually established. @@ -508,6 +519,7 @@ func TestHandshakeWrongResponderPacketStore(t *testing.T) { } func TestHandshakeRelayComplete(t *testing.T) { + t.Parallel() // Verify that a relay handshake completes correctly and relay state is // properly maintained on all three nodes. 
diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 93f200ac..43fa72f2 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -84,6 +84,7 @@ func BenchmarkHotPathRelay(b *testing.B) { } func TestGoodHandshake(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -134,6 +135,7 @@ func TestGoodHandshake(t *testing.T) { } func TestGoodHandshakeNoOverlap(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack! 
@@ -169,6 +171,7 @@ func TestGoodHandshakeNoOverlap(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) @@ -245,6 +248,7 @@ func TestWrongResponderHandshake(t *testing.T) { } func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) @@ -327,6 +331,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { } func TestStage1Race(t *testing.T) { + t.Parallel() // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel @@ -407,6 +412,7 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -456,6 +462,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", 
nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -507,6 +514,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -536,6 +544,7 @@ func TestRelays(t *testing.T) { } func TestRelaysDontCareAboutIps(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}}) @@ -565,6 +574,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) { } func TestReestablishRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -696,6 +706,7 @@ func TestReestablishRelays(t *testing.T) { } func TestStage1RaceRelays(t *testing.T) { + t.Parallel() //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay 
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) @@ -743,6 +754,7 @@ func TestStage1RaceRelays(t *testing.T) { } func TestStage1RaceRelays2(t *testing.T) { + t.Parallel() //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) @@ -819,6 +831,7 @@ func TestStage1RaceRelays2(t *testing.T) { } func TestRehandshakingRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -922,6 +935,7 @@ func TestRehandshakingRelays(t *testing.T) { } func TestRehandshakingRelaysPrimary(t *testing.T) { + t.Parallel() // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) @@ -1026,6 +1040,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { + 
t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) @@ -1121,6 +1136,7 @@ func TestRehandshaking(t *testing.T) { } func TestRehandshakingLoser(t *testing.T) { + t.Parallel() // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -1219,6 +1235,7 @@ func TestRehandshakingLoser(t *testing.T) { } func TestRaceRegression(t *testing.T) { + t.Parallel() // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo @@ -1279,6 +1296,7 @@ func TestRaceRegression(t *testing.T) { } func TestV2NonPrimaryWithLighthouse(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) @@ -1319,6 +1337,7 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) { } func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", 
"2001::1/64", m{"lighthouse": m{"am_lighthouse": true}}) @@ -1359,6 +1378,7 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { } func TestLighthouseUpdateOnReload(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) // Create the lighthouse @@ -1434,6 +1454,7 @@ func TestLighthouseUpdateOnReload(t *testing.T) { } func TestGoodHandshakeUnsafeDest(t *testing.T) { + t.Parallel() unsafePrefix := "192.168.6.0/24" ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil) diff --git a/e2e/leak_test.go b/e2e/leak_test.go new file mode 100644 index 00000000..ffb024fe --- /dev/null +++ b/e2e/leak_test.go @@ -0,0 +1,51 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "go.uber.org/goleak" +) + +// TestNoGoroutineLeaks brings up two nebula instances, completes a tunnel, +// stops both, and asserts no goroutines leak past the shutdown. goleak's +// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain +// before failing the assertion. +// +// IgnoreCurrent is necessary in the parallelized suite: other tests can +// leave goroutines mid-shutdown when this one runs (Stop is async, the +// wg.Wait() drain is not blocking on test return). We're checking that +// *this* test's setup tears down cleanly, not that the whole suite is +// idle at this moment. Intentionally NOT t.Parallel()'d for the same +// reason — concurrent test goroutines would always show up. 
+func TestNoGoroutineLeaks(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() + r.RenderFlow() + + // Settle period: Stop() is non-blocking; the wg-driven goroutines need + // a moment to drain. goleak retries internally too, but a short explicit + // settle reduces flakes when the suite is busy. 
+ time.Sleep(50 * time.Millisecond) +} diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index e8e41945..63c655f3 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -19,6 +19,7 @@ import ( ) func TestDropInactiveTunnels(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -63,6 +64,7 @@ func TestDropInactiveTunnels(t *testing.T) { } func TestCertUpgrade(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -157,6 +159,7 @@ func TestCertUpgrade(t *testing.T) { } func TestCertDowngrade(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -255,6 +258,7 @@ func TestCertDowngrade(t *testing.T) { } func TestCertMismatchCorrection(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -322,6 +326,7 @@ func TestCertMismatchCorrection(t *testing.T) { } func TestCrossStackRelaysWork(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, 
nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}}) @@ -369,6 +374,7 @@ func TestCrossStackRelaysWork(t *testing.T) { } func TestCloseTunnelAuthenticated(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}}) diff --git a/go.mod b/go.mod index 0de2df7d..24d901c5 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.11.1 github.com/vishvananda/netlink v1.3.1 + go.uber.org/goleak v1.3.0 go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 From 33c2d7277c3a6f43b3ef63dc87f7e4754722ca29 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 1 May 2026 13:21:38 -0500 Subject: [PATCH 41/44] Reduce HandshakeManager complexity a little bit (#1701) --- handshake_manager.go | 144 +------------------------------------ main.go | 10 +-- relay_manager.go | 166 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 163 insertions(+), 157 deletions(-) diff --git a/handshake_manager.go b/handshake_manager.go index 9fc69ff4..87257028 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -23,7 +23,6 @@ const ( DefaultHandshakeTryInterval = time.Millisecond * 100 
DefaultHandshakeRetries = 10 DefaultHandshakeTriggerBuffer = 64 - DefaultUseRelays = true // maxCachedPackets is how many unsent packets we'll buffer per pending // handshake before dropping further ones. @@ -43,7 +42,6 @@ var ( tryInterval: DefaultHandshakeTryInterval, retries: DefaultHandshakeRetries, triggerBuffer: DefaultHandshakeTriggerBuffer, - useRelays: DefaultUseRelays, } ) @@ -51,7 +49,6 @@ type HandshakeConfig struct { tryInterval time.Duration retries int64 triggerBuffer int - useRelays bool messageMetrics *MessageMetrics } @@ -326,146 +323,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered ) } - if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) - // Send a RelayRequest to all known Relay IP's - for _, relay := range hostinfo.remotes.relays { - // Don't relay through the host I'm trying to connect to - if relay == vpnIp { - continue - } - - // Don't relay to myself - if hm.f.myVpnAddrsTable.Contains(relay) { - continue - } - - relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) - if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { - hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String()) - hm.f.Handshake(relay) - continue - } - // Check the relay HostInfo to see if we already established a relay through - existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) - if !ok { - // No relays exist or requested yet. 
- if relayHostInfo.remote.IsValid() { - idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) - if err != nil { - hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) - } - - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: idx, - } - - switch relayHostInfo.GetCert().Certificate.Version() { - case cert.Version1: - if !hm.f.myVpnAddrs[0].Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") - continue - } - - if !vpnIp.Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") - continue - } - - b := hm.f.myVpnAddrs[0].As4() - m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) - b = vpnIp.As4() - m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) - case cert.Version2: - m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) - m.RelayToAddr = netAddrToProtoAddr(vpnIp) - default: - hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") - continue - } - - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) - } else { - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.Info("send CreateRelayRequest", - "relayFrom", hm.f.myVpnAddrs[0], - "relayTo", vpnIp, - "initiatorRelayIndex", idx, - "relay", relay, - ) - } - } - continue - } - - switch existingRelay.State { - case Established: - hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String()) - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[handshakePacketStage0], make([]byte, 12), make([]byte, mtu), false) - case Disestablished: - // Mark this relay as 'requested' - 
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) - fallthrough - case Requested: - hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String()) - // Re-send the CreateRelay request, in case the previous one was lost. - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: existingRelay.LocalIndex, - } - - switch relayHostInfo.GetCert().Certificate.Version() { - case cert.Version1: - if !hm.f.myVpnAddrs[0].Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") - continue - } - - if !vpnIp.Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") - continue - } - - b := hm.f.myVpnAddrs[0].As4() - m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) - b = vpnIp.As4() - m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) - case cert.Version2: - m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) - m.RelayToAddr = netAddrToProtoAddr(vpnIp) - default: - hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") - continue - } - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) - } else { - // This must send over the hostinfo, not over hm.Hosts[ip] - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.Info("send CreateRelayRequest", - "relayFrom", hm.f.myVpnAddrs[0], - "relayTo", vpnIp, - "initiatorRelayIndex", existingRelay.LocalIndex, - "relay", relay, - ) - } - case PeerRequested: - // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. 
- fallthrough - default: - hostinfo.logger(hm.l).Error("Relay unexpected state", - "vpnIp", vpnIp, - "state", existingRelay.State, - "relay", relay, - ) - - } - } - } + hm.f.relayManager.StartRelays(hm.f, vpnIp, hostinfo, stage0) // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { diff --git a/main.go b/main.go index eef13c97..d5e5dcc8 100644 --- a/main.go +++ b/main.go @@ -184,14 +184,10 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev messageMetrics = newMessageMetricsOnlyRecvError() } - useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false) - handshakeConfig := HandshakeConfig{ - tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), - triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), - useRelays: useRelays, - + tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), + retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), + triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), messageMetrics: messageMetrics, } diff --git a/relay_manager.go b/relay_manager.go index 919bb2b6..25e65871 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -15,9 +15,10 @@ import ( ) type relayManager struct { - l *slog.Logger - hostmap *HostMap - amRelay atomic.Bool + l *slog.Logger + hostmap *HostMap + amRelay atomic.Bool + useRelays atomic.Bool } func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager { @@ -36,8 +37,10 @@ func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *c } func (rm *relayManager) reload(c *config.C, initial bool) error { - if initial || c.HasChanged("relay.am_relay") { - rm.setAmRelay(c.GetBool("relay.am_relay", 
false)) + if initial || c.HasChanged("relay.am_relay") || c.HasChanged("relay.use_relays") { + amRelay := c.GetBool("relay.am_relay", false) + rm.amRelay.Store(amRelay) + rm.useRelays.Store(c.GetBool("relay.use_relays", true) && !amRelay) } return nil } @@ -46,8 +49,157 @@ func (rm *relayManager) GetAmRelay() bool { return rm.amRelay.Load() } -func (rm *relayManager) setAmRelay(v bool) { - rm.amRelay.Store(v) +func (rm *relayManager) GetUseRelays() bool { + return rm.useRelays.Load() +} + +// StartRelays drives the relay-establishment side of an outbound handshake attempt. +// For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest, retransmits +// one that may have been lost, or, once the relay is Established, forwards the in-progress +// stage 0 handshake packet for vpnIp through it. +func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *HostInfo, stage0 []byte) { + if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 { + return + } + + hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) + // Send a RelayRequest to all known Relay IP's + for _, relay := range hostinfo.remotes.relays { + // Don't relay through the host I'm trying to connect to + if relay == vpnIp { + continue + } + + // Don't relay to myself + if f.myVpnAddrsTable.Contains(relay) { + continue + } + + relayHostInfo := rm.hostmap.QueryVpnAddr(relay) + if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { + hostinfo.logger(rm.l).Info("Establish tunnel to relay target", "relay", relay.String()) + f.Handshake(relay) + continue + } + // Check the relay HostInfo to see if we already established a relay through + existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) + if !ok { + // No relays exist or requested yet. 
+ if relayHostInfo.remote.IsValid() { + idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested) + if err != nil { + hostinfo.logger(rm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) + } + + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: idx, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !f.myVpnAddrs[0].Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + continue + } + + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + } else { + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.Info("send CreateRelayRequest", + "relayFrom", f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", idx, + "relay", relay, + ) + } + } + continue + } + + switch existingRelay.State { + case Established: + hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String()) + f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false) + case Disestablished: + // Mark this relay as 'requested' + relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) + fallthrough + case Requested: + 
hostinfo.logger(rm.l).Info("Re-send CreateRelay request", "relay", relay.String()) + // Re-send the CreateRelay request, in case the previous one was lost. + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: existingRelay.LocalIndex, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !f.myVpnAddrs[0].Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + } else { + // This must send over the hostinfo, not over hm.Hosts[ip] + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.Info("send CreateRelayRequest", + "relayFrom", f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", existingRelay.LocalIndex, + "relay", relay, + ) + } + case PeerRequested: + // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. + fallthrough + default: + hostinfo.logger(rm.l).Error("Relay unexpected state", + "vpnIp", vpnIp, + "state", existingRelay.State, + "relay", relay, + ) + + } + } } // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. 
From b7e9939e921aab000699115330fb31f33c6449b9 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 4 May 2026 10:12:58 -0500 Subject: [PATCH 42/44] More stable e2e test harness, better for benchmarking (#1702) --- control_tester.go | 72 ++-------- e2e/handshake_manager_test.go | 24 ++-- e2e/handshakes_test.go | 76 ++++++----- e2e/helpers_test.go | 59 ++++++++- e2e/router/router.go | 242 ++++++++++++++++++++++++++-------- e2e/tunnels_test.go | 4 +- overlay/tun_tester.go | 54 +++++++- udp/udp_tester.go | 67 +++++++--- 8 files changed, 418 insertions(+), 180 deletions(-) diff --git a/control_tester.go b/control_tester.go index f927140b..728ac649 100644 --- a/control_tester.go +++ b/control_tester.go @@ -5,8 +5,6 @@ package nebula import ( "net/netip" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" @@ -22,7 +20,9 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message panic(err) } pipeTo.InjectUDPPacket(p) - if h.Type == msgType && h.Subtype == subType { + match := h.Type == msgType && h.Subtype == subType + p.Release() + if match { return } } @@ -38,7 +38,9 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, panic(err) } pipeTo.InjectUDPPacket(p) - if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType { + match := h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType + p.Release() + if match { return } } @@ -90,65 +92,15 @@ func (c *Control) GetTunTxChan() <-chan []byte { return c.f.inside.(*overlay.TestTun).TxPackets } -// InjectUDPPacket will inject a packet into the udp side of nebula +// InjectUDPPacket injects a packet into the udp side. We copy internally so the caller keeps ownership of p. +// The copy comes from the freelist so steady-state alloc is zero. 
func (c *Control) InjectUDPPacket(p *udp.Packet) { - c.f.outside.(*udp.TesterConn).Send(p) + c.f.outside.(*udp.TesterConn).Send(p.Copy()) } -// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { - serialize := make([]gopacket.SerializableLayer, 0) - var netLayer gopacket.NetworkLayer - if toAddr.Is6() { - if !fromAddr.Is6() { - panic("Cant send ipv6 to ipv4") - } - ip := &layers.IPv6{ - Version: 6, - NextHeader: layers.IPProtocolUDP, - SrcIP: fromAddr.Unmap().AsSlice(), - DstIP: toAddr.Unmap().AsSlice(), - } - serialize = append(serialize, ip) - netLayer = ip - } else { - if !fromAddr.Is4() { - panic("Cant send ipv4 to ipv6") - } - - ip := &layers.IPv4{ - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: fromAddr.Unmap().AsSlice(), - DstIP: toAddr.Unmap().AsSlice(), - } - serialize = append(serialize, ip) - netLayer = ip - } - - udp := layers.UDP{ - SrcPort: layers.UDPPort(fromPort), - DstPort: layers.UDPPort(toPort), - } - err := udp.SetNetworkLayerForChecksum(netLayer) - if err != nil { - panic(err) - } - - buffer := gopacket.NewSerializeBuffer() - opt := gopacket.SerializeOptions{ - ComputeChecksums: true, - FixLengths: true, - } - - serialize = append(serialize, &udp, gopacket.Payload(data)) - err = gopacket.SerializeLayers(buffer, opt, serialize...) - if err != nil { - panic(err) - } - - c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) +// InjectTunPacket pushes an IP packet onto the tun interface. 
+func (c *Control) InjectTunPacket(packet []byte) { + c.f.inside.(*overlay.TestTun).Send(packet) } func (c *Control) GetVpnAddrs() []netip.Addr { diff --git a/e2e/handshake_manager_test.go b/e2e/handshake_manager_test.go index 1c6ebacc..b06564d1 100644 --- a/e2e/handshake_manager_test.go +++ b/e2e/handshake_manager_test.go @@ -47,7 +47,7 @@ func TestHandshakeRetransmitDuplicate(t *testing.T) { defer r.RenderFlow() t.Log("Trigger handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) t.Log("Grab my msg1") msg1 := myControl.GetFromUDP(true) @@ -97,7 +97,7 @@ func TestHandshakeTruncatedPacketRecovery(t *testing.T) { defer r.RenderFlow() t.Log("Trigger handshake") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) t.Log("Get msg1 and deliver to responder") msg1 := myControl.GetFromUDP(true) @@ -146,7 +146,7 @@ func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { defer r.RenderFlow() t.Log("Complete a normal handshake") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) r.RouteForAllUntilTxTun(theirControl) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) @@ -248,7 +248,7 @@ func TestHandshakeLateResponse(t *testing.T) { theirControl.Start() t.Log("Trigger handshake from me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) t.Log("Grab msg1 but don't 
deliver") msg1 := myControl.GetFromUDP(true) @@ -292,7 +292,7 @@ func TestHandshakeSelfConnectionRejected(t *testing.T) { myControl.Start() t.Log("Trigger handshake from me") - myControl.InjectTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) msg1 := myControl.GetFromUDP(true) t.Log("Drain any handshake retransmits before injecting") @@ -375,7 +375,7 @@ func TestHandshakeRemoteAllowList(t *testing.T) { defer r.RenderFlow() t.Log("Trigger handshake from them") - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi"))) msg1 := theirControl.GetFromUDP(true) t.Log("Rewrite the source to a blocked IP and inject") @@ -426,7 +426,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { defer r.RenderFlow() t.Log("Complete a normal handshake via the router") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) r.RouteForAllUntilTxTun(theirControl) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) @@ -437,7 +437,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { originalRemote := hi.CurrentRemote t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam"))) r.RouteForAllUntilTxTun(theirControl) t.Log("Verify tunnel still works") @@ -475,8 +475,8 @@ func TestHandshakeWrongResponderPacketStore(t 
*testing.T) { evilControl.Start() t.Log("Send multiple packets to them (cached during handshake)") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")) - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1"))) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2"))) t.Log("Route until evil tunnel is closed") h := &header.H{} @@ -540,7 +540,7 @@ func TestHandshakeRelayComplete(t *testing.T) { theirControl.Start() t.Log("Trigger handshake via relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay"))) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -568,7 +568,7 @@ func TestHandshakeRelayComplete(t *testing.T) { } // NOTE: Relay V1 cert + IPv6 rejection is not tested here because -// InjectTunUDPPacket from a V4 node to a V6 address panics in the test +// BuildTunUDPPacket from a V4 node to a V6 address panics in the test // framework. The check is in handshake_manager.go handleOutbound relay // logic (lines ~304-313): if the relay host has a V1 cert and either // address is IPv6, the relay is skipped. 
diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 43fa72f2..d0b9543c 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -16,6 +16,7 @@ import ( "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -39,11 +40,22 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + // Pre-build the IP packet bytes once so the bench measures the data plane, + // not gopacket SerializeLayers overhead. + prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + // EnableFanIn switches the router to a 0-alloc routing path. Required + // for hot-path benchmarks; would conflict with GetFromUDP-using tests. + r.EnableFanIn() + b.ResetTimer() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - _ = r.RouteForAllUntilTxTun(theirControl) + myControl.InjectTunPacket(prebuilt) + // Release the TUN-side bytes back to the harness freelist; the bench + // just confirms a packet arrived, the contents aren't inspected. 
+ overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl)) } myControl.Stop() @@ -71,11 +83,15 @@ func BenchmarkHotPathRelay(b *testing.B) { theirControl.Start() assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + + prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + r.EnableFanIn() + b.ResetTimer() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - _ = r.RouteForAllUntilTxTun(theirControl) + myControl.InjectTunPacket(prebuilt) + overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl)) } myControl.Stop() @@ -97,7 +113,7 @@ func TestGoodHandshake(t *testing.T) { theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) t.Log("Have them consume my stage 0 packet. 
They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -191,7 +207,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -273,7 +289,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -352,8 +368,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -430,7 +446,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 
myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -441,7 +457,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))) p = r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -480,7 +496,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -492,7 +508,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp 
hostmaps", myControl, theirControl) @@ -535,7 +551,7 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -565,7 +581,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -595,14 +611,14 @@ func TestReestablishRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Ensure packet traversal from them to me via the relay") - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) r.Log("Assert the tunnel works") @@ -617,7 +633,7 @@ func TestReestablishRelays(t *testing.T) { for curIndexes >= start { curIndexes = len(myControl.GetHostmap().Indexes) 
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes) - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { return router.RouteAndExit @@ -634,7 +650,7 @@ func TestReestablishRelays(t *testing.T) { myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p = r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -669,7 +685,7 @@ func TestReestablishRelays(t *testing.T) { t.Log("Assert the tunnel works the other way, too") for { t.Log("RouteForAllUntilTxTun") - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) r.Log("Assert the tunnel works") @@ -739,8 +755,8 @@ func TestStage1RaceRelays(t *testing.T) { assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + 
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -787,8 +803,8 @@ func TestStage1RaceRelays2(t *testing.T) { assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -852,7 +868,7 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -957,7 +973,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 
myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -1259,8 +1275,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -1476,7 +1492,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) { theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) t.Log("Have them consume my stage 0 packet. 
They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -1504,7 +1520,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) { assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80) //reply - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))) //wait for reply theirControl.WaitForType(1, 0, myControl) theirCachedPacket := myControl.GetFromTun(true) diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 381ae897..b555fbc4 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -294,12 +294,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me - controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) + controlB.InjectTunPacket(BuildTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them - controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) + controlA.InjectTunPacket(BuildTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } @@ -408,3 +408,58 @@ func testLogLevelName() string { } return "info" } + +// BuildTunUDPPacket assembles an IP+UDP packet suitable for Control.InjectTunPacket. +// Using UDP here because it's a simpler protocol. 
+func BuildTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) []byte { + serialize := make([]gopacket.SerializableLayer, 0) + var netLayer gopacket.NetworkLayer + if toAddr.Is6() { + if !fromAddr.Is6() { + panic("Cant send ipv6 to ipv4") + } + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } else { + if !fromAddr.Is4() { + panic("Cant send ipv4 to ipv6") + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } + + udp := layers.UDP{ + SrcPort: layers.UDPPort(fromPort), + DstPort: layers.UDPPort(toPort), + } + if err := udp.SetNetworkLayerForChecksum(netLayer); err != nil { + panic(err) + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + serialize = append(serialize, &udp, gopacket.Payload(data)) + if err := gopacket.SerializeLayers(buffer, opt, serialize...); err != nil { + panic(err) + } + + return buffer.Bytes() +} diff --git a/e2e/router/router.go b/e2e/router/router.go index c8264ab7..72012073 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -13,6 +13,7 @@ import ( "regexp" "sort" "sync" + "sync/atomic" "testing" "time" @@ -24,6 +25,19 @@ import ( "golang.org/x/exp/maps" ) +// outNatKey is the (from, to) pair used by outNat. Comparable struct, so it works as a map key without the +// allocation cost of a string-concat key. +type outNatKey struct { + from, to netip.AddrPort +} + +// fannedPacket pairs a UDP TX packet with its source control so the router can route it after popping from +// the fan-in channel. 
+type fannedPacket struct { + from *nebula.Control + pkt *udp.Packet +} + type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? @@ -34,12 +48,28 @@ type R struct { // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender - // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver - outNat map[string]netip.AddrPort + outNat map[outNatKey]netip.AddrPort // A map of vpn ip to the nebula control it belongs to vpnControls map[netip.Addr]*nebula.Control + // Cached select infrastructure for RouteForAllUntilTxTun. + // The controls map is immutable after NewR so the cases are good for the test lifetime. + // We only rebuild if a different receiver is asked. + selRecvCtl *nebula.Control + selCases []reflect.SelectCase + selCtls []*nebula.Control + + // Optional fan-in mode for hot-path benchmarks: one forwarder goroutine per control drains UDP TX into udpFanIn, + // so RouteForAllUntilTxTun can do a fixed 2-way native select instead of paying reflect.Select per call. + // Off by default (would otherwise interleave with tests that use GetFromUDP directly on the same control). + // Enabled by EnableFanIn. 
+ udpFanIn chan fannedPacket + stopFanIn chan struct{} + fanInWG sync.WaitGroup + fanInMu sync.Mutex + fanInOn atomic.Bool + ignoreFlows []ignoreFlow flow []flowEntry @@ -119,7 +149,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { controls: make(map[netip.AddrPort]*nebula.Control), vpnControls: make(map[netip.Addr]*nebula.Control), inNat: make(map[netip.AddrPort]*nebula.Control), - outNat: make(map[string]netip.AddrPort), + outNat: make(map[outNatKey]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), @@ -153,8 +183,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { case <-ctx.Done(): return case <-clockSource.C: + r.Lock() r.renderHostmaps("clock tick") r.renderFlow() + r.Unlock() } } }() @@ -180,15 +212,21 @@ func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { // RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening. func (r *R) RenderFlow() { r.cancelRender() + r.Lock() + defer r.Unlock() r.renderFlow() } // CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected func (r *R) CancelFlowLogs() { r.cancelRender() + r.Lock() r.flow = nil + r.Unlock() } +// renderFlow writes the flow log to disk. Caller must hold r.Lock. renderFlow reads r.flow / r.additionalGraphs and +// the *packet pointers stashed inside, all of which are mutated under the same lock by routing paths. 
func (r *R) renderFlow() { if r.flow == nil { return @@ -434,68 +472,157 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) - c.InjectUDPPacket(p) + c.InjectUDPPacket(p) // copies internally; original is ours to release fp.WasReceived() r.Unlock() + p.Release() } } } -// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on receivers tun -// If the router doesn't have the nebula controller for that address, we panic +// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on the receiver's tun. +// If a control's UDP TX address can't be matched to a registered control, we panic. +// +// For allocation-sensitive callers (hot-path benchmarks, in particular relay +// benches with 3+ controls), call EnableFanIn() first. func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { + if r.fanInOn.Load() { + return r.routeFanIn(receiver) + } + return r.routeReflect(receiver) +} + +// routeFanIn is the alloc-free path used when EnableFanIn is in effect. +func (r *R) routeFanIn(receiver *nebula.Control) []byte { + tunTx := receiver.GetTunTxChan() + for { + select { + case p := <-tunTx: + r.Lock() + if r.flow != nil { + np := udp.Packet{Data: make([]byte, len(p))} + copy(np.Data, p) + r.unlockedInjectFlow(receiver, receiver, &np, true) + } + r.Unlock() + return p + case fp := <-r.udpFanIn: + r.routeUDP(fp.from, fp.pkt) + } + } +} + +// routeReflect is the default reflect.Select-based path. Pays the boxing allocation per call but doesn't interfere +// with tests that pull packets directly from controls' UDP TX channels via GetFromUDP. 
+func (r *R) routeReflect(receiver *nebula.Control) []byte { + sc, cm := r.selectCasesFor(receiver) + for { + x, rx, _ := reflect.Select(sc) + if x == 0 { + p := rx.Interface().([]byte) + r.Lock() + if r.flow != nil { + np := udp.Packet{Data: make([]byte, len(p))} + copy(np.Data, p) + r.unlockedInjectFlow(cm[x], cm[x], &np, true) + } + r.Unlock() + return p + } + r.routeUDP(cm[x], rx.Interface().(*udp.Packet)) + } +} + +// EnableFanIn switches RouteForAllUntilTxTun to the alloc-free fan-in path. +// One forwarder goroutine per registered control drains UDP TX into a shared channel that RouteForAllUntilTxTun selects +// on alongside the receiver's TUN TX channel. +func (r *R) EnableFanIn() { + r.fanInMu.Lock() + defer r.fanInMu.Unlock() + if r.fanInOn.Load() { + return + } + r.udpFanIn = make(chan fannedPacket, 32) + r.stopFanIn = make(chan struct{}) + for _, c := range r.controls { + r.startFanInWorker(c) + } + r.fanInOn.Store(true) + r.t.Cleanup(r.stopFanInWorkers) +} + +// startFanInWorker spawns a goroutine that drains c's UDP TX into r.udpFanIn. +func (r *R) startFanInWorker(c *nebula.Control) { + r.fanInWG.Add(1) + udpTx := c.GetUDPTxChan() + go func() { + defer r.fanInWG.Done() + for { + select { + case <-r.stopFanIn: + return + case p := <-udpTx: + select { + case <-r.stopFanIn: + p.Release() + return + case r.udpFanIn <- fannedPacket{from: c, pkt: p}: + } + } + } + }() +} + +// stopFanInWorkers signals the fan-in goroutines to exit and waits for them. +func (r *R) stopFanInWorkers() { + r.fanInMu.Lock() + wasOn := r.fanInOn.Swap(false) + r.fanInMu.Unlock() + if !wasOn { + return + } + close(r.stopFanIn) + r.fanInWG.Wait() +} + +// routeUDP forwards a UDP TX packet from the named source control to the destination control derived from p.To, +// releasing the source packet after InjectUDPPacket has copied its bytes into a fresh pool slot. 
+func (r *R) routeUDP(from *nebula.Control, p *udp.Packet) { + r.Lock() + defer r.Unlock() + a := from.GetUDPAddr() + c := r.getControl(a, p.To, p) + if c == nil { + panic(fmt.Sprintf("No control for udp tx %s", p.To)) + } + fp := r.unlockedInjectFlow(from, c, p, false) + c.InjectUDPPacket(p) // copies internally; original is ours to release + fp.WasReceived() + p.Release() +} + +// selectCasesFor returns the SelectCase array used by routeReflect: one slot for the receiver's TUN TX channel followed +// by one per control's UDP TX channel. Cached for the test lifetime, only rebuilt if the receiver changes. +func (r *R) selectCasesFor(receiver *nebula.Control) ([]reflect.SelectCase, []*nebula.Control) { + r.Lock() + defer r.Unlock() + if r.selRecvCtl == receiver && r.selCases != nil { + return r.selCases, r.selCtls + } sc := make([]reflect.SelectCase, len(r.controls)+1) cm := make([]*nebula.Control, len(r.controls)+1) - - i := 0 - sc[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(receiver.GetTunTxChan()), - Send: reflect.Value{}, - } - cm[i] = receiver - - i++ + sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(receiver.GetTunTxChan())} + cm[0] = receiver + i := 1 for _, c := range r.controls { - sc[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(c.GetUDPTxChan()), - Send: reflect.Value{}, - } - + sc[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan())} cm[i] = c i++ } - - for { - x, rx, _ := reflect.Select(sc) - r.Lock() - - if x == 0 { - // we are the tun tx, we can exit - p := rx.Interface().([]byte) - np := udp.Packet{Data: make([]byte, len(p))} - copy(np.Data, p) - - r.unlockedInjectFlow(cm[x], cm[x], &np, true) - r.Unlock() - return p - - } else { - // we are a udp tx, route and continue - p := rx.Interface().(*udp.Packet) - a := cm[x].GetUDPAddr() - c := r.getControl(a, p.To, p) - if c == nil { - r.Unlock() - panic(fmt.Sprintf("No control for 
udp tx %s", p.To)) - } - fp := r.unlockedInjectFlow(cm[x], c, p, false) - c.InjectUDPPacket(p) - fp.WasReceived() - } - r.Unlock() - } + r.selRecvCtl = receiver + r.selCases = sc + r.selCtls = cm + return sc, cm } // RouteExitFunc will call the whatDo func with each udp packet from sender. @@ -522,6 +649,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { switch e { case ExitNow: r.Unlock() + p.Release() return case RouteAndExit: @@ -529,6 +657,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { receiver.InjectUDPPacket(p) fp.WasReceived() r.Unlock() + p.Release() return case KeepRouting: @@ -541,6 +670,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { } r.Unlock() + p.Release() } } @@ -641,6 +771,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { switch e { case ExitNow: r.Unlock() + p.Release() return case RouteAndExit: @@ -648,6 +779,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { receiver.InjectUDPPacket(p) fp.WasReceived() r.Unlock() + p.Release() return case KeepRouting: @@ -659,6 +791,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { panic(fmt.Sprintf("Unknown exitFunc return: %v", e)) } r.Unlock() + p.Release() } } @@ -702,19 +835,20 @@ func (r *R) FlushAll() { } receiver.InjectUDPPacket(p) r.Unlock() + p.Release() } } // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { - if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { + if newAddr, ok := r.outNat[outNatKey{from: fromAddr, to: toAddr}]; ok { p.From = newAddr } c, ok := r.inNat[toAddr] if ok { - r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr + r.outNat[outNatKey{from: c.GetUDPAddr(), to: fromAddr}] = toAddr return c } diff --git a/e2e/tunnels_test.go 
b/e2e/tunnels_test.go index 63c655f3..697f25af 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -355,14 +355,14 @@ func TestCrossStackRelaysWork(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) t.Log("reply?") - theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index b2c2a0ea..8acd83f0 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -15,6 +15,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" + "github.com/slackhq/nebula/udp" ) type TestTun struct { @@ -54,9 +55,12 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTu return nil, fmt.Errorf("newTunFromFd not supported") } -// Send will place a byte array onto the receive queue for nebula to consume +// Send will place a byte array onto the receive queue for nebula to consume. // These are unencrypted ip layer frames destined for another nebula node. -// packets should exit the udp side, capture them with udpConn.Get +// packets should exit the udp side, capture them with udpConn.Get. +// +// Send copies the input via the freelist, so the caller is free to mutate +// or reuse it after the call returns. 
func (t *TestTun) Send(packet []byte) { if t.closed.Load() { return @@ -65,7 +69,9 @@ func (t *TestTun) Send(packet []byte) { if t.l.Enabled(context.Background(), slog.LevelDebug) { t.l.Debug("Tun receiving injected packet", "dataLen", len(packet)) } - t.rxPackets <- packet + buf := acquireTunBuf(len(packet)) + copy(buf, packet) + t.rxPackets <- buf } // Get will pull an unencrypted ip layer frame from the transmit queue @@ -110,12 +116,44 @@ func (t *TestTun) Write(b []byte) (n int, err error) { return 0, io.ErrClosedPipe } - packet := make([]byte, len(b), len(b)) + packet := acquireTunBuf(len(b)) copy(packet, b) t.TxPackets <- packet return len(b), nil } +// ReleaseTunBuf returns a slice from TxPackets to the harness freelist, don't use the bytes after the call. +// Channel-backed instead of sync.Pool because putting a []byte in a sync.Pool escapes the slice header to heap. +func ReleaseTunBuf(b []byte) { + if b == nil { + return + } + select { + case tunBufFreelist <- b: + default: + // Freelist full; drop the buffer for the GC. + } +} + +// tunBufFreelist retains the backing arrays for TestTun.Write so steady-state allocation drops to zero once the +// freelist has saturated for the current MTU. +var tunBufFreelist = make(chan []byte, 64) + +func acquireTunBuf(n int) []byte { + var b []byte + select { + case b = <-tunBufFreelist: + default: + b = make([]byte, 0, udp.MTU) + } + if cap(b) < n { + b = make([]byte, n) + } else { + b = b[:n] + } + return b +} + func (t *TestTun) Close() error { if t.closed.CompareAndSwap(false, true) { close(t.rxPackets) @@ -129,8 +167,14 @@ func (t *TestTun) Read(b []byte) (int, error) { if !ok { return 0, os.ErrClosed } + n := len(p) copy(b, p) - return len(p), nil + // Send always pushes a freelist-acquired slice, return it once we've copied the bytes into the caller's buffer. 
+ select { + case tunBufFreelist <- p: + default: + } + return n, nil } func (t *TestTun) SupportsMultiqueue() bool { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index fcd0967c..f872e32a 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -21,17 +21,48 @@ type Packet struct { Data []byte } +// Copy returns a fresh *Packet (from the freelist) with a duplicate Data buffer. func (u *Packet) Copy() *Packet { - n := &Packet{ - To: u.To, - From: u.From, - Data: make([]byte, len(u.Data)), + n := acquirePacket() + n.To = u.To + n.From = u.From + if cap(n.Data) < len(u.Data) { + n.Data = make([]byte, len(u.Data)) + } else { + n.Data = n.Data[:len(u.Data)] } - copy(n.Data, u.Data) return n } +// Release returns p to the harness packet freelist. +// Callers that pull a *Packet from Get / TxPackets must Release when done. +// Channel-backed instead of sync.Pool because sync.Pool's per-P caches drain badly under cross-goroutine Get/Put, +// and putting a []byte in a Pool escapes the slice header to heap. +func (p *Packet) Release() { + if p == nil { + return + } + p.Data = p.Data[:0] + select { + case packetFreelist <- p: + default: + // Freelist full; drop the *Packet for the GC. + } +} + +// packetFreelist retains *Packet structs (and their backing Data arrays) so steady-state allocation drops to zero. 
+var packetFreelist = make(chan *Packet, 64) + +func acquirePacket() *Packet { + select { + case p := <-packetFreelist: + return p + default: + return &Packet{} + } +} + type TesterConn struct { Addr netip.AddrPort @@ -64,13 +95,15 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *TesterConn) Send(packet *Packet) { - h := &header.H{} - if err := h.Parse(packet.Data); err != nil { - panic(err) - } if u.l.Enabled(context.Background(), slog.LevelDebug) { + // Parse the header only under debug logging, otherwise the + // allocation would show up in every Send call. + var h header.H + if err := h.Parse(packet.Data); err != nil { + panic(err) + } u.l.Debug("UDP receiving injected packet", - "header", h, + "header", &h, "udpAddr", packet.From, "dataLen", len(packet.Data), ) @@ -107,15 +140,18 @@ func (u *TesterConn) Get(block bool) *Packet { //********************************************************************************************************************// func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { - p := &Packet{ - Data: make([]byte, len(b), len(b)), - From: u.Addr, - To: addr, + p := acquirePacket() + if cap(p.Data) < len(b) { + p.Data = make([]byte, len(b)) + } else { + p.Data = p.Data[:len(b)] } - copy(p.Data, b) + p.From = u.Addr + p.To = addr select { case <-u.done: + p.Release() return io.ErrClosedPipe case u.TxPackets <- p: return nil @@ -129,6 +165,7 @@ func (u *TesterConn) ListenOut(r EncReader) error { return os.ErrClosed case p := <-u.RxPackets: r(p.From, p.Data) + p.Release() } } } From ff91c37529509ffb26137bff4d4ded9eac9113a6 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Wed, 6 May 2026 10:22:26 -0500 Subject: [PATCH 43/44] switch Bits to a packed u64 (#1705) --- bits.go | 209 +++++++++++++++++++++----- bits_test.go | 407 
+++++++++++++++++++++++++++++++++++---------------- 2 files changed, 452 insertions(+), 164 deletions(-) diff --git a/bits.go b/bits.go index 5c8f902b..15bafd87 100644 --- a/bits.go +++ b/bits.go @@ -2,24 +2,42 @@ package nebula import ( "context" + "fmt" "log/slog" + "math" + mathbits "math/bits" "github.com/rcrowley/go-metrics" ) +const bitsPerWord = 64 + +// Bits is a sliding-window anti-replay tracker. The window is stored as a +// circular bitmap packed into uint64 words (8x denser than a []bool), so a +// length-N window costs N/8 bytes. length must be a power of two. type Bits struct { length uint64 + lengthMask uint64 current uint64 - bits []bool + bits []uint64 lostCounter metrics.Counter dupeCounter metrics.Counter outOfWindowCounter metrics.Counter } -func NewBits(bits uint64) *Bits { +func NewBits(length uint64) *Bits { + if length == 0 || length&(length-1) != 0 { + panic(fmt.Sprintf("Bits length must be a power of two, got %d", length)) + } + + nWords := length / bitsPerWord + if nWords == 0 { + nWords = 1 + } b := &Bits{ - length: bits, - bits: make([]bool, bits, bits), + length: length, + lengthMask: length - 1, + bits: make([]uint64, nWords), current: 0, lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), @@ -27,71 +45,194 @@ func NewBits(bits uint64) *Bits { } // There is no counter value 0, mark it to avoid counting a lost packet later. 
- b.bits[0] = true - b.current = 0 + b.bits[0] = 1 return b } +func (b *Bits) get(i uint64) bool { + pos := i & b.lengthMask + //bit-shifting by 6 because i is a bit index, not a u64 index, and we need to find the u64 without bit in it + return b.bits[pos>>6]&(uint64(1)<<(pos&63)) != 0 +} + +func (b *Bits) set(i uint64) { + pos := i & b.lengthMask + b.bits[pos>>6] |= uint64(1) << (pos & 63) +} + +// clearRange clears `count` bits starting at circular position `startPos` +// (already masked to [0, length)) and returns how many of them were set +// before the clear. count must be in [1, length]. +func (b *Bits) clearRange(startPos, count uint64) uint64 { + wasSet := uint64(0) + if count >= b.length { + for _, w := range b.bits { + wasSet += uint64(mathbits.OnesCount64(w)) + } + clear(b.bits) + return wasSet + } + + pos := startPos + remaining := count + + // handle the potential partial word before pos becomes u64 aligned + word := pos >> 6 + bit := pos & 63 + take := uint64(64) - bit + if take > remaining { + take = remaining + } + if take > b.length-pos { + take = b.length - pos + } + var mask uint64 + if take == 64 { + mask = math.MaxUint64 + } else { + mask = ((uint64(1) << take) - 1) << bit + } + wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask)) + b.bits[word] &^= mask + remaining -= take + pos = (pos + take) & b.lengthMask + + // Clear whole words, keeping track of the number of set bits + for remaining >= 64 { + word = pos >> 6 + wasSet += uint64(mathbits.OnesCount64(b.bits[word])) + b.bits[word] = 0 + remaining -= 64 + pos = (pos + 64) & b.lengthMask + } + + // Clear the remaining partial word + if remaining > 0 { + word = pos >> 6 + mask = (uint64(1) << remaining) - 1 + wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask)) + b.bits[word] &^= mask + } + + return wasSet +} + +func (b *Bits) strictlyWithinWindow(i uint64) bool { + // Handle the case where the window hasn't slid yet. This avoids u64 underflow. 
+ inWarmup := b.current < b.length + if i < b.length && inWarmup { + return true + } + + // Next, if the packet is in-window, see if we've seen it before + if i > b.current-b.length { + return true + } + return false //not within window! +} + +// Check returns true if i is within (or way out in front of) the window, and not a replay func (b *Bits) Check(l *slog.Logger, i uint64) bool { // If i is the next number, return true. if i > b.current { return true } - // If i is within the window, check if it's been set already. - if i > b.current-b.length || i < b.length && b.current < b.length { - return !b.bits[i%b.length] + if b.strictlyWithinWindow(i) { + return !b.get(i) } // Not within the window if l.Enabled(context.Background(), slog.LevelDebug) { - l.Debug("rejected a packet (top)", - "current", b.current, - "incoming", i, - ) + l.Debug("rejected a packet (top)", "current", b.current, "incoming", i) } return false } +// Update has three branches: +// - i == b.current+1: fast path; advance the cursor by one and lose-count +// the slot we just stomped (only past warmup; see the i > b.length guard +// below). +// - i > b.current+1: jump path; clear all slots between current and i +// (or up to a full window's worth, whichever is smaller) via clearRange, +// then mark i. Two arms here: a warmup arm that handles the very first +// window before the cursor has slid, and a steady-state arm that treats +// every cleared empty slot as a lost packet. +// - i <= b.current: in-window check for duplicates; out-of-window otherwise. +// +// NewBits seeds bits[0]=1 so counter 0 looks "received" — Update never +// clears that marker during warmup (clearRange skips position 0 when +// startPos=1), and once b.current >= b.length the marker is no longer +// consulted. The marker prevents a fictitious "lost" hit on the first real +// counter. func (b *Bits) Update(l *slog.Logger, i uint64) bool { - // If i is the next number, return true and update current. 
+ // Fast path: i is the next expected counter. Split out so the function + // stays small and avoids paying for the slow paths' slog argument-build + // stack frame on every call. The bit read/test/write is inlined to + // touch the backing word once. if i == b.current+1 { - // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter - // The very first window can only be tracked as lost once we are on the 2nd window or greater - if b.bits[i%b.length] == false && i > b.length { + pos := i & b.lengthMask + word := pos >> 6 + mask := uint64(1) << (pos & 63) + w := b.bits[word] + if i > b.length && w&mask == 0 { b.lostCounter.Inc(1) } - b.bits[i%b.length] = true + b.bits[word] = w | mask b.current = i return true } + return b.updateSlow(l, i) +} +// updateSlow handles jumps, in-window backfill, dupes, and out-of-window. +func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool { // If i is a jump, adjust the window, record lost, update current, and return true if i > b.current { - lost := int64(0) - // Zero out the bits between the current and the new counter value, limited by the window size, - // since the window is shifting - for n := b.current + 1; n <= min(i, b.current+b.length); n++ { - if b.bits[n%b.length] == false && n > b.length { - lost++ + end := i + if end > b.current+b.length { + end = b.current + b.length + } + count := end - b.current + startPos := (b.current + 1) & b.lengthMask + + var lost int64 + if b.current >= b.length { + // Steady state: every cleared slot is past warmup, so any unset + // bit we evict is a lost packet from the previous cycle. + wasSet := b.clearRange(startPos, count) + lost = int64(count) - int64(wasSet) + } else { + // Warmup (the very first window). Some cleared slots represent + // packets <= length where eviction is not "lost" in the usual + // sense. This branch is taken at most once per connection so we + // don't bother optimizing it. 
+ for n := b.current + 1; n <= end; n++ { + if !b.get(n) && n > b.length { + lost++ + } } - b.bits[n%b.length] = false + b.clearRange(startPos, count) } - // Only record any skipped packets as a result of the window moving further than the window length - // Any loss within the new window will be accounted for in future calls - lost += max(0, int64(i-b.current-b.length)) + // Anything past the new window can never be backfilled, so it's lost. + if i > b.current+b.length { + lost += int64(i - b.current - b.length) + } b.lostCounter.Inc(lost) - b.bits[i%b.length] = true + b.set(i) b.current = i return true } - // If i is within the current window but below the current counter, - // Check to see if it's a duplicate - if i > b.current-b.length || i < b.length && b.current < b.length { - if b.current == i || b.bits[i%b.length] == true { + // If i is within the current window but below the current counter, check to see if it's a duplicate + if b.strictlyWithinWindow(i) { + pos := i & b.lengthMask + word := pos >> 6 + mask := uint64(1) << (pos & 63) + w := b.bits[word] + if b.current == i || w&mask != 0 { if l.Enabled(context.Background(), slog.LevelDebug) { l.Debug("Receive window", "accepted", false, @@ -104,7 +245,7 @@ func (b *Bits) Update(l *slog.Logger, i uint64) bool { return false } - b.bits[i%b.length] = true + b.bits[word] = w | mask return true } diff --git a/bits_test.go b/bits_test.go index 3504cefa..da44c92a 100644 --- a/bits_test.go +++ b/bits_test.go @@ -7,61 +7,79 @@ import ( "github.com/stretchr/testify/assert" ) +// snapshot returns the bitmap as a []bool of length b.length, for readable +// test assertions against the now-packed []uint64 storage. 
+func (b *Bits) snapshot() []bool { + out := make([]bool, b.length) + for i := uint64(0); i < b.length; i++ { + out[i] = b.get(i) + } + return out +} + +func TestBitsRequiresPowerOfTwo(t *testing.T) { + assert.Panics(t, func() { NewBits(10) }) + assert.Panics(t, func() { NewBits(0) }) + assert.NotPanics(t, func() { NewBits(1) }) + assert.NotPanics(t, func() { NewBits(16) }) + assert.NotPanics(t, func() { NewBits(1024) }) + assert.NotPanics(t, func() { NewBits(16384) }) +} + func TestBits(t *testing.T) { l := test.NewLogger() - b := NewBits(10) - - // make sure it is the right size - assert.Len(t, b.bits, 10) + b := NewBits(16) + assert.EqualValues(t, 16, b.length) // This is initialized to zero - receive one. This should work. assert.True(t, b.Check(l, 1)) assert.True(t, b.Update(l, 1)) assert.EqualValues(t, 1, b.current) - g := []bool{true, true, false, false, false, false, false, false, false, false} - assert.Equal(t, g, b.bits) + g := []bool{true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) // Receive two assert.True(t, b.Check(l, 2)) assert.True(t, b.Update(l, 2)) assert.EqualValues(t, 2, b.current) - g = []bool{true, true, true, false, false, false, false, false, false, false} - assert.Equal(t, g, b.bits) + g = []bool{true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) // Receive two again - it will fail assert.False(t, b.Check(l, 2)) assert.False(t, b.Update(l, 2)) assert.EqualValues(t, 2, b.current) - // Jump ahead to 15, which should clear everything and set the 6th element - assert.True(t, b.Check(l, 15)) - assert.True(t, b.Update(l, 15)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, false, true, false, false, false, false} - assert.Equal(t, g, b.bits) + // Jump ahead to 25, which clears the window and sets slot 25%16 = 9. 
+ assert.True(t, b.Check(l, 25)) + assert.True(t, b.Update(l, 25)) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // Mark 14, which is allowed because it is in the window - assert.True(t, b.Check(l, 14)) - assert.True(t, b.Update(l, 14)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, true, true, false, false, false, false} - assert.Equal(t, g, b.bits) + // Mark 24, which is in window (current 25, length 16, window covers [10,25]). + assert.True(t, b.Check(l, 24)) + assert.True(t, b.Update(l, 24)) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // Mark 5, which is not allowed because it is not in the window + // Mark 5, not allowed because 5 <= current-length (25-16=9). assert.False(t, b.Check(l, 5)) assert.False(t, b.Update(l, 5)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, true, true, false, false, false, false} - assert.Equal(t, g, b.bits) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // make sure we handle wrapping around once to the current position - b = NewBits(10) + // Make sure we handle wrapping around once to the same slot. With + // length=16, packets 1 and 17 share slot 1. 
+ b = NewBits(16) assert.True(t, b.Update(l, 1)) - assert.True(t, b.Update(l, 11)) - assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits) + assert.True(t, b.Update(l, 17)) + assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}, b.snapshot()) // Walk through a few windows in order - b = NewBits(10) + b = NewBits(16) for i := uint64(1); i <= 100; i++ { assert.True(t, b.Check(l, i), "Error while checking %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i) @@ -72,24 +90,31 @@ func TestBits(t *testing.T) { func TestBitsLargeJumps(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + + // length=16. Update(55) from current=0: + // warmup, per-bit loop sees no n>16 with unset bits (slot 0 was set by + // NewBits and gets re-evaluated when n=16; n=16 is not strictly > 16), + // so the loop contributes 0. The jump exceeds the window so we record + // 55 - 0 - 16 = 39 packets fell out the back. + b := NewBits(16) b.lostCounter.Clear() + assert.True(t, b.Update(l, 55)) + assert.Equal(t, int64(39), b.lostCounter.Count()) - b = NewBits(10) - b.lostCounter.Clear() - assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54 - assert.Equal(t, int64(45), b.lostCounter.Count()) + // Update(100): clears 16 slots starting at slot 56%16=8. Only slot 7 (for + // packet 55) was set, so 16 - 1 = 15 evicted slots had unset bits. + // Plus 100 - 55 - 16 = 29 packets fell past the window. Total 44. 
+ assert.True(t, b.Update(l, 100)) + assert.Equal(t, int64(39+44), b.lostCounter.Count()) - assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99 - assert.Equal(t, int64(89), b.lostCounter.Count()) - - assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199 - assert.Equal(t, int64(188), b.lostCounter.Count()) + // Update(200): same shape: 16 - 1 = 15 evicted unset, plus 200 - 100 - 16 = 84 past window. Total 99. + assert.True(t, b.Update(l, 200)) + assert.Equal(t, int64(39+44+99), b.lostCounter.Count()) } func TestBitsDupeCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() @@ -114,120 +139,117 @@ func TestBitsDupeCounter(t *testing.T) { func TestBitsOutOfWindowCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() + // Jump to 20 (warmup branch + 4 past-window packets). assert.True(t, b.Update(l, 20)) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) - assert.True(t, b.Update(l, 21)) - assert.True(t, b.Update(l, 22)) - assert.True(t, b.Update(l, 23)) - assert.True(t, b.Update(l, 24)) - assert.True(t, b.Update(l, 25)) - assert.True(t, b.Update(l, 26)) - assert.True(t, b.Update(l, 27)) - assert.True(t, b.Update(l, 28)) - assert.True(t, b.Update(l, 29)) + // 9 single-step advances, each evicts a slot whose bit was cleared during + // the jump above and whose value was never seen, so each contributes 1 + // to lostCounter. + for n := uint64(21); n <= 29; n++ { + assert.True(t, b.Update(l, n)) + } assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) + // 0 is below current-length (29-16=13) so it falls outside the window. 
assert.False(t, b.Update(l, 0)) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) - assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost + // 4 from the Update(20) jump + 9 from 21..29. + assert.Equal(t, int64(13), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) } func TestBitsLostCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(l, 20)) - assert.True(t, b.Update(l, 21)) - assert.True(t, b.Update(l, 22)) - assert.True(t, b.Update(l, 23)) - assert.True(t, b.Update(l, 24)) - assert.True(t, b.Update(l, 25)) - assert.True(t, b.Update(l, 26)) - assert.True(t, b.Update(l, 27)) - assert.True(t, b.Update(l, 28)) - assert.True(t, b.Update(l, 29)) - assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost + // Walk 20..29 like the original, just with a bigger window. Same + // reasoning as TestBitsOutOfWindowCounter: 4 past-window from Update(20), + // then 9 more from the unit advances. 
+ for n := uint64(20); n <= 29; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(13), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) - b = NewBits(10) + b = NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(l, 9)) - assert.Equal(t, int64(0), b.lostCounter.Count()) - // 10 will set 0 index, 0 was already set, no lost packets - assert.True(t, b.Update(l, 10)) - assert.Equal(t, int64(0), b.lostCounter.Count()) - // 11 will set 1 index, 1 was missed, we should see 1 packet lost - assert.True(t, b.Update(l, 11)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - // Now let's fill in the window, should end up with 8 lost packets - assert.True(t, b.Update(l, 12)) - assert.True(t, b.Update(l, 13)) - assert.True(t, b.Update(l, 14)) + // Update(15) clears the warmup window (no lost), sets slot 15. assert.True(t, b.Update(l, 15)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + + // Update(16): slot 0 was already set (NewBits seeded it), and 16 is not + // strictly > length, so nothing is recorded as lost. assert.True(t, b.Update(l, 16)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + + // Update(17): we jumped straight from 0 to 15, so slot 1 was cleared + // (and never re-set). 17 > 16 is past warmup, so packet 1 is recorded lost. 
assert.True(t, b.Update(l, 17)) - assert.True(t, b.Update(l, 18)) - assert.True(t, b.Update(l, 19)) - assert.Equal(t, int64(8), b.lostCounter.Count()) + assert.Equal(t, int64(1), b.lostCounter.Count()) - // Jump ahead by a window size - assert.True(t, b.Update(l, 29)) - assert.Equal(t, int64(8), b.lostCounter.Count()) - // Now lets walk ahead normally through the window, the missed packets should fill in - assert.True(t, b.Update(l, 30)) - assert.True(t, b.Update(l, 31)) - assert.True(t, b.Update(l, 32)) - assert.True(t, b.Update(l, 33)) - assert.True(t, b.Update(l, 34)) - assert.True(t, b.Update(l, 35)) - assert.True(t, b.Update(l, 36)) - assert.True(t, b.Update(l, 37)) - assert.True(t, b.Update(l, 38)) - // 39 packets tracked, 22 seen, 17 lost - assert.Equal(t, int64(17), b.lostCounter.Count()) + // Fill in 18..30 in single steps. Each i evicts slot i%16. Slots 2..14 + // were all cleared during Update(15), and we never re-set any of them, + // so each i in 18..30 is a fresh lost packet — 13 more. + for n := uint64(18); n <= 30; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(14), b.lostCounter.Count()) - // Jump ahead by 2 windows, should have recording 1 full window missing - assert.True(t, b.Update(l, 58)) - assert.Equal(t, int64(27), b.lostCounter.Count()) - // Now lets walk ahead normally through the window, the missed packets should fill in from this window - assert.True(t, b.Update(l, 59)) - assert.True(t, b.Update(l, 60)) - assert.True(t, b.Update(l, 61)) - assert.True(t, b.Update(l, 62)) - assert.True(t, b.Update(l, 63)) - assert.True(t, b.Update(l, 64)) - assert.True(t, b.Update(l, 65)) - assert.True(t, b.Update(l, 66)) - assert.True(t, b.Update(l, 67)) - // 68 packets tracked, 32 seen, 36 missed - assert.Equal(t, int64(36), b.lostCounter.Count()) + // Jump ahead by exactly one window size. + assert.True(t, b.Update(l, 46)) + // end = min(46, 30+16) = 46, count = 16, all slots cleared. 
Before the + // jump every slot 0..15 had been set (Update(15), (16), (17), 18..30), + // so wasSet=16 and 46 == current+length means no past-window slack: + // lost contribution = 0. + assert.Equal(t, int64(14), b.lostCounter.Count()) + + // Walk 47..55. The Update(46) jump cleared every slot, so only slot 14 + // (for packet 46) is set when we start. Each subsequent unit step lands + // on a slot that was cleared and is past warmup, so it counts as lost. + // 9 more = 23. + for n := uint64(47); n <= 55; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(23), b.lostCounter.Count()) + + // Jump ahead by two windows: clears the window plus past-window loss. + assert.True(t, b.Update(l, 87)) + // current=55, length=16. end = min(87, 71) = 71. count=16, all slots + // cleared. Slots set before the clear are slots 14,15,0..7 (10 total). + // Lost from clear = 16 - 10 = 6. Past window: 87 - 55 - 16 = 16. +22. + assert.Equal(t, int64(45), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } func TestBitsLostCounterIssue1(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() + // Receive 4, backfill 1, then 9, 2, 3, 5, 6, 7 (skip 8), 10, 11, 14. + // Then jump to 25 — slot 25%16=9 is being evicted, but it had been set + // (we received packet 9), so no spurious lost increment. The original + // regression was about double-counting a missing packet when its slot + // got cleared on a jump. With the jump path now using clearRange's + // word-level wasSet count, the same semantics hold. 
assert.True(t, b.Update(l, 4)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 1)) @@ -244,7 +266,7 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 7)) assert.Equal(t, int64(0), b.lostCounter.Count()) - // assert.True(t, b.Update(l, 8)) + // Skip packet 8. assert.True(t, b.Update(l, 10)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 11)) @@ -252,9 +274,23 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.True(t, b.Update(l, 14)) assert.Equal(t, int64(0), b.lostCounter.Count()) - // Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter - assert.True(t, b.Update(l, 19)) + + // Jump to 25. With length=16, slot 25%16=9 corresponds to packet 9 + // (which we DID receive), so its bit is set and no lost++ from that + // eviction. The trace below shows the only loss is packet 8. + assert.True(t, b.Update(l, 25)) + // current was 14, i=25. end=min(25,30)=25. count=11. startPos=15. + // steady? current=14<16, so warmup branch: per-bit n=15..25, count those + // with !get(n) AND n>16. n=17..25 are >16. Among slots 17%16=1..25%16=9 + // did we set slots 1..9 (packets 1..9)? Yes for all but slot 8 (packet 8 + // was skipped). n=24 maps to slot 8 which is FALSE → lost++. All other + // n in 17..25 map to slots that are set. n=16 is not strictly > 16. So + // lost = 1. assert.Equal(t, int64(1), b.lostCounter.Count()) + + // Fill in 12, 13, 15, 16. Each is below current=25 (in-window). 16 must + // recheck slot 0 — it was set by NewBits and then cleared by the + // Update(25) jump, so 16 backfills cleanly. 
assert.True(t, b.Update(l, 12)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 13)) @@ -263,29 +299,140 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 16)) assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 17)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 18)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 20)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 21)) - // We missed packet 8 above + // We missed packet 8 above and that loss is still recorded once, never + // double-counted, never zeroed. assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } -func BenchmarkBits(b *testing.B) { - z := NewBits(10) - for n := 0; n < b.N; n++ { - for i := range z.bits { - z.bits[i] = true - } - for i := range z.bits { - z.bits[i] = false - } +// TestBitsWarmupOvershoot exercises the jump path's warmup arm with an +// overshoot past one full window. NewBits leaves current=0 with only slot 0 +// "set" by the marker. Jumping straight to length+k must (a) clear every +// slot the jump straddles, (b) count only past-window slack (not the +// in-window slots, which never had a "lost" tenant during warmup), and +// (c) leave the cursor at the new counter so subsequent unit advances +// count from steady state. The marker bit at slot 0 is irrelevant once +// current >= length. +func TestBitsWarmupOvershoot(t *testing.T) { + l := test.NewLogger() + b := NewBits(16) + b.lostCounter.Clear() + // Jump from current=0 to i=20 (length=16, overshoot=4). + // Warmup arm: counts slots in [1..16] where bit unset and n>length. + // Only n=16 was unset and >length: but slot 16%16=0 is the marker, + // so b.get(16) reads bits[0]=1 and skips. 
Result: 0 lost from the loop. + // Past-window: i - current - length = 20 - 0 - 16 = 4 lost. + assert.True(t, b.Update(l, 20)) + assert.Equal(t, int64(4), b.lostCounter.Count()) + assert.Equal(t, uint64(20), b.current) + + // Steady state now (current=20 >= length=16). Unit advance to 21 + // stomps slot 21%16=5, which was cleared by the jump and not reset, + // so this is +1 lost. + assert.True(t, b.Update(l, 21)) + assert.Equal(t, int64(5), b.lostCounter.Count()) +} + +// TestBitsCheckAcrossWarmupBoundary pins the underflow trick in Check's +// in-window clause. While in warmup, b.current-b.length underflows uint64 +// to a huge value so the first OR-clause is always false; the second +// clause (i < length && current < length) carries the in-window check. +// Once current >= length the regimes flip cleanly. +func TestBitsCheckAcrossWarmupBoundary(t *testing.T) { + l := test.NewLogger() + b := NewBits(16) + + // Warmup: current=0. Check(0) must read the marker (set) and return false. + assert.False(t, b.Check(l, 0), "marker slot should look already-received") + // Warmup: any 0 < i < length is in-window and unset → accepted. + for i := uint64(1); i < 16; i++ { + assert.True(t, b.Check(l, i), "warmup in-window i=%d should be accepted", i) + } + // Warmup: i >= length but > current is "next number" so accepted. + assert.True(t, b.Check(l, 16)) + assert.True(t, b.Check(l, 1_000_000)) + + // Cross into steady state. + assert.True(t, b.Update(l, 100)) + // Now current=100, length=16. In-window range is [85..100]. + // 84 is just outside: the underflow clause activates; 84 > 100-16=84 is false. + // And the warmup clause is false (current >= length). So out of window. + assert.False(t, b.Check(l, 84)) + // 85 sits at the boundary. 85 > 84 is true → in window, unset → accept. + assert.True(t, b.Check(l, 85)) + // 100 is current itself; not strictly greater, in-window, but already set. + assert.False(t, b.Check(l, 100)) + // Way out: clearly out of window. 
+ assert.False(t, b.Check(l, 50)) +} + +// TestBitsMarkerInvariant verifies the seeded bits[0]=1 marker behaves +// correctly across warmup and beyond. Update should never clear the marker +// during warmup (clearRange skips position 0 when startPos=1), and once +// current >= length the marker is no longer consulted by Check/Update on +// the live path — but it must still report counter 0 as a duplicate while +// we are in warmup. +func TestBitsMarkerInvariant(t *testing.T) { + l := test.NewLogger() + b := NewBits(8) + + // Counter 0 is the seeded marker; Check sees it as already received. + assert.False(t, b.Check(l, 0)) + // Update(0) at current=0 hits the duplicate branch. + b.dupeCounter.Clear() + assert.False(t, b.Update(l, 0)) + assert.Equal(t, int64(1), b.dupeCounter.Count()) + + // Walk forward through warmup; the marker must remain set. + for n := uint64(1); n <= 7; n++ { + assert.True(t, b.Update(l, n)) + } + // Position 0 (the marker) should still read as set because we never + // cleared it; Update(0) still looks like a duplicate. + assert.False(t, b.Check(l, 0)) + + // Cross into steady state with a unit advance to 8: pos=0, evicts the + // marker bit. The lost-counter guard (i > b.length) is false (8 == 8), + // so this advance does NOT charge a lost packet — exactly what the + // marker is there to prevent. + b.lostCounter.Clear() + assert.True(t, b.Update(l, 8)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + // The slot at pos 0 is now occupied by counter 8. + assert.False(t, b.Check(l, 8)) +} + +// BenchmarkBitsUpdateInOrder is the steady-state hot path: each call is +// i == current+1. +func BenchmarkBitsUpdateInOrder(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + z.Update(l, uint64(n)+1) + } +} + +// BenchmarkBitsUpdateReorder simulates light reorder within the window: +// every other packet arrives one slot behind its predecessor (forces the +// in-window backfill branch). 
+func BenchmarkBitsUpdateReorder(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + base := uint64(n) * 2 + z.Update(l, base+2) + z.Update(l, base+1) + } +} + +// BenchmarkBitsUpdateLargeJumps stresses the clearRange word-level path. +func BenchmarkBitsUpdateLargeJumps(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + z.Update(l, uint64(n+1)*1000) } } From 4fb5cdb4faaa1c47ef0c8e59fb46641db707dca9 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 6 May 2026 12:23:27 -0400 Subject: [PATCH 44/44] refactor readOutsidePackets (#1642) * refactor readOutsidePackets They layout of this method is confusing and relys on certain parts to return early for things to work correctly. Change the ordering of the logic so that we do this: - Handle unencrypted packets - Decrypt packet - Handle encrypted packets This way, nothing can sneak through unencrypted to where it shouldn't be. * fix comment * code review comments * check for expected type/subtype * check header version * log header * need to handle TestReply * clean roaming / connectionManager * dont need to roam here now, we do it earlier * cleanup metrics and errors * rxInvalid * debug logger checks * ErrOutOfWindow --- header/header.go | 14 ++ message_metrics.go | 8 + outside.go | 413 +++++++++++++++++++++------------------------ 3 files changed, 210 insertions(+), 225 deletions(-) diff --git a/header/header.go b/header/header.go index f22509b8..b973141f 100644 --- a/header/header.go +++ b/header/header.go @@ -174,6 +174,10 @@ func (h *H) SubTypeName() string { return SubTypeName(h.Type, h.Subtype) } +func (h *H) IsValidSubType() bool { + return IsValidSubType(h.Type, h.Subtype) +} + // SubTypeName will transform a nebula message sub type into a human string func SubTypeName(t MessageType, s MessageSubType) string { if n, ok := subTypeMap[t]; ok { @@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string { return 
"unknown" } +func IsValidSubType(t MessageType, s MessageSubType) bool { + if n, ok := subTypeMap[t]; ok { + if _, ok := (*n)[s]; ok { + return true + } + } + + return false +} + // NewHeader turns bytes into a header func NewHeader(b []byte) (*H, error) { h := new(H) diff --git a/message_metrics.go b/message_metrics.go index 10e8472c..45de9a5c 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -13,6 +13,8 @@ type MessageMetrics struct { rxUnknown metrics.Counter txUnknown metrics.Counter + + rxInvalid metrics.Counter } func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { @@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int } } } +func (m *MessageMetrics) RxInvalid(i int64) { + if m != nil && m.rxInvalid != nil { + m.rxInvalid.Inc(i) + } +} func newMessageMetrics() *MessageMetrics { gen := func(t string) [][]metrics.Counter { @@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics { rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil), txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil), + rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil), } } diff --git a/outside.go b/outside.go index 1e00a0a9..17013ed3 100644 --- a/outside.go +++ b/outside.go @@ -20,23 +20,46 @@ const ( minFwPacketLen = 4 ) +var ErrOutOfWindow = errors.New("out of window packet") + func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors + // TODO: record metrics for rx holepunch/punchy packets? 
if len(packet) > 1 { - f.l.Info("Error while parsing inbound packet", - "from", via, - "error", err, - "packet", packet, - ) + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while parsing inbound packet", + "from", via, + "error", err, + "packet", packet, + ) + } + } + return + } + + if h.Version != header.Version { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected header version received", "from", via) + } + return + } + + // Check before processing to see if this is a expected type/subtype + if !h.IsValidSubType() { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected packet received", "from", via) } return } - //l.Error("in packet ", header, packet[HeaderLen:]) if !via.IsRelayed { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { + f.messageMetrics.RxInvalid(1) if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("Refusing to process double encrypted packet", "from", via) } @@ -44,215 +67,192 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } } + // don't keep Rx metrics for message type, since you can see those in the tun metrics + if h.Type != header.Message { + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + } + + // Unencrypted packets + switch h.Type { + case header.Handshake: + f.handshakeManager.HandleIncoming(via, packet, h) + return + + case header.RecvError: + f.handleRecvError(via.UdpAddr, h) + return + } + + // Relay packets are special + isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay) + var hostinfo *HostInfo - // verify if we've seen this index before, otherwise respond to the handshake initiation - if h.Type == header.Message && h.Subtype == header.MessageRelay { + if isMessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { hostinfo = 
f.hostMap.QueryIndex(h.RemoteIndex) } - var ci *ConnectionState - if hostinfo != nil { - ci = hostinfo.ConnectionState + // At this point we should have a valid existing tunnel, verify and send + // recvError if necessary + if hostinfo == nil || hostinfo.ConnectionState == nil { + if !via.IsRelayed { + f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) + } + return } + // All remaining packets are encrypted + ci := hostinfo.ConnectionState + if !ci.window.Check(f.l, h.MessageCounter) { + return + } + + // Relay packets are special + if isMessageRelay { + f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache) + + return + } + + out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + if err != nil { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Failed to decrypt packet", + "error", err, + "from", via, + "header", h, + ) + } + return + } + + // Roam before we respond + f.handleHostRoaming(hostinfo, via) + f.connectionManager.In(hostinfo) + switch h.Type { case header.Message: - if !f.handleEncrypted(ci, via, h) { - return - } - switch h.Subtype { case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { - return - } - case header.MessageRelay: - // The entire body is sent as AD, not encrypted. - // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. - // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's - // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice - // which will gracefully fail in the DecryptDanger call. 
- signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] - signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) - if err != nil { - return - } - // Successfully validated the thing. Get rid of the Relay header. - signedPayload = signedPayload[header.Len:] - // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, via) - // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo) - f.connectionManager.RelayUsed(h.RemoteIndex) - - relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) - if !ok { - // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing - // its internal mapping. This should never happen. - hostinfo.logger(f.l).Error("HostInfo missing remote relay index", - "vpnAddrs", hostinfo.vpnAddrs, - "remoteIndex", h.RemoteIndex, - ) - return - } - - switch relay.Type { - case TerminalType: - // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. 
- via = ViaSender{ - UdpAddr: via.UdpAddr, - relayHI: hostinfo, - remoteIdx: relay.RemoteIndex, - relay: relay, - IsRelayed: true, - } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) - return - case ForwardingType: - // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) - if err != nil { - hostinfo.logger(f.l).Info("Failed to find target host info by ip", - "relayTo", relay.PeerAddr, - "error", err, - "hostinfo.vpnAddrs", hostinfo.vpnAddrs, - ) - return - } - - // If that relay is Established, forward the payload through it - if targetRelay.State == Established { - switch targetRelay.Type { - case ForwardingType: - // Forward this packet through the relay tunnel - // Find the target HostInfo - f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) - return - case TerminalType: - hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") - } - } else { - hostinfo.logger(f.l).Info("Unexpected target relay state", - "relayTo", relay.PeerAddr, - "relayFrom", hostinfo.vpnAddrs[0], - "targetRelayState", targetRelay.State, - ) - return - } - } + f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) + return } case header.LightHouse: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - //TODO: assert via is not relayed - lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) - - // Fallthrough to the bottom to record incoming traffic + lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f) case header.Test: - 
f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { + switch h.Subtype { + case header.TestReply: + // No-op, useful for the Roaming and connectionManager side-effects above + case header.TestRequest: + f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h) return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt test packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - - if h.Subtype == header.TestRequest { - // This testRequest might be from TryPromoteBest, so we should roam - // to the new IP address before responding - f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) - } - - // Fallthrough to the bottom to record incoming traffic - - // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they - // are unauthenticated - - case header.Handshake: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(via, packet, h) - return - - case header.RecvError: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(via.UdpAddr, h) - return - case header.CloseTunnel: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via) - f.closeTunnel(hostinfo) - return case header.Control: - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - 
hostinfo.logger(f.l).Error("Failed to decrypt Control packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - - f.relayManager.HandleControlMsg(hostinfo, d, f) + f.relayManager.HandleControlMsg(hostinfo, out, f) default: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via) - } + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h) + } +} + +func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { + // The entire body is sent as AD, not encrypted. + // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. + // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's + // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice + // which will gracefully fail in the DecryptDanger call. + signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] + signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] + var err error + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) + if err != nil { + return + } + // Successfully validated the thing. Get rid of the Relay header. + signedPayload = signedPayload[header.Len:] + // Pull the Roaming parts up here, and return in all call paths. 
+ f.handleHostRoaming(hostinfo, via) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet + f.connectionManager.In(hostinfo) + f.connectionManager.RelayUsed(h.RemoteIndex) + + relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) + if !ok { + // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing + // its internal mapping. This should never happen. + hostinfo.logger(f.l).Error("HostInfo missing remote relay index", + "vpnAddrs", hostinfo.vpnAddrs, + "remoteIndex", h.RemoteIndex, + ) return } - f.handleHostRoaming(hostinfo, via) + switch relay.Type { + case TerminalType: + // If I am the target of this relay, process the unwrapped packet + // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. + via = ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, + } + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + case ForwardingType: + // Find the target HostInfo relay object + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + if err != nil { + hostinfo.logger(f.l).Info("Failed to find target host info by ip", + "relayTo", relay.PeerAddr, + "error", err, + "hostinfo.vpnAddrs", hostinfo.vpnAddrs, + ) + return + } - f.connectionManager.In(hostinfo) + // If that relay is Established, forward the payload through it + if targetRelay.State == Established { + switch targetRelay.Type { + case ForwardingType: + // Forward this packet through the relay tunnel + // Find the target HostInfo + f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) + case TerminalType: + hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") + return + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", 
"from", via, "relayType", targetRelay.Type) + } + return + } + } else { + hostinfo.logger(f.l).Info("Unexpected target relay state", + "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], + "targetRelayState", targetRelay.State, + ) + return + } + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type) + } + } } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote @@ -300,23 +300,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { } -// handleEncrypted returns true if a packet should be processed, false otherwise -func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { - // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect - if ci == nil { - if !via.IsRelayed { - f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) - } - return false - } - // If the window check fails, refuse to process the packet, but don't send a recv error - if !ci.window.Check(f.l, h.MessageCounter) { - return false - } - - return true -} - var ( ErrPacketTooShort = errors.New("packet is too short") ErrUnknownIPVersion = errors.New("packet is an unknown ip version") @@ -523,38 +506,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h) - } - return nil, errors.New("out of window packet") + return nil, ErrOutOfWindow } return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { - var err error - - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], 
packet[header.Len:], messageCounter, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) - return false - } - - err = newPacket(out, true, fwPacket) +func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { + err := newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", "error", err, "packet", out, ) - return false - } - - if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) - } - return false + return } dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) @@ -568,15 +533,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out "reason", dropReason, ) } - return false + return } - f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { f.l.Error("Failed to write to tun", "error", err) } - return true } func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {