From f7540ad3556b5a4a9b4bf846b1a3f8e885468f64 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Fri, 7 Mar 2025 14:37:07 -0600 Subject: [PATCH 01/44] Remove commented out metadata.go (#1320) --- metadata.go | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 metadata.go diff --git a/metadata.go b/metadata.go deleted file mode 100644 index 6a023ab..0000000 --- a/metadata.go +++ /dev/null @@ -1,18 +0,0 @@ -package nebula - -/* - -import ( - proto "google.golang.org/protobuf/proto" -) - -func HandleMetaProto(p []byte) { - m := &NebulaMeta{} - err := proto.Unmarshal(p, m) - if err != nil { - l.Debugf("problem unmarshaling meta message: %s", err) - } - //fmt.Println(m) -} - -*/ From 94e89a10453a0f33a96ae5cda30f4d28606858d2 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 10 Mar 2025 10:17:54 -0400 Subject: [PATCH 02/44] smoke-tests: guess the lighthouse container IP better (#1347) Currently we just assume you are using the default Docker bridge network config of `172.17.0.0/24`. This change works to try to detect if you are using a different config, but still only works if you are using a `/24` and aren't running any other containers. A future PR could make this better by launching the lighthouse container first and then fetching what the IP address is before continuing with the configuration. --- .github/workflows/smoke/build.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index c546653..dcd132b 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -5,6 +5,10 @@ set -e -x rm -rf ./build mkdir ./build +# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1 +# - We could make this better by launching the lighthouse first and then fetching what IP it is. +NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)" + ( cd build @@ -21,16 +25,16 @@ mkdir ./build ../genconfig.sh >lighthouse1.yml HOST="host2" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ ../genconfig.sh >host2.yml HOST="host3" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host3.yml HOST="host4" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host4.yml From 612637f5290186c29e71f88ccfa9fcbda06e1666 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Mon, 10 Mar 2025 09:18:34 -0500 Subject: [PATCH 03/44] Fix `testifylint` lint errors (#1321) * Fix bool-compare * Fix empty * Fix encoded-compare * Fix error-is-as * Fix error-nil * Fix expected-actual * Fix len --- allow_list_test.go | 30 ++++----- cert/ca_pool_test.go | 118 ++++++++++++++++----------------- cert/cert_v1_test.go | 36 +++++----- cert/cert_v2_test.go | 30 ++++----- cert/crypto_test.go | 10 +-- cert/pem_test.go | 22 +++--- cert/sign_test.go | 12 ++-- cmd/nebula-cert/ca_test.go | 34 +++++----- cmd/nebula-cert/keygen_test.go | 14 ++-- cmd/nebula-cert/print_test.go | 6 +- cmd/nebula-cert/sign_test.go | 38 +++++------ cmd/nebula-cert/verify_test.go | 9 ++- config/config_test.go | 24 +++---- firewall_test.go | 86 ++++++++++++------------ handshake_manager_test.go | 2 +- header/header_test.go | 2 +- lighthouse_test.go | 2 +- outside_test.go | 20 +++--- overlay/route_test.go | 24 +++---- punchy_test.go | 16 ++--- 20 files changed, 267 insertions(+), 268 deletions(-) diff --git a/allow_list_test.go b/allow_list_test.go index c8b3d08..6d5e76b 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -98,7 +98,7 @@ func TestNewAllowListFromConfig(t *testing.T) { } func TestAllowList_Allow(t *testing.T) { - assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) + assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) tree := new(bart.Table[bool]) tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true) @@ -111,17 +111,17 @@ func TestAllowList_Allow(t *testing.T) { tree.Insert(netip.MustParsePrefix("::2/128"), false) al := &AllowList{cidrTree: tree} - assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1"))) - assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4"))) - assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42"))) - assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41"))) - assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1"))) - assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1"))) - assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2"))) + assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1"))) + assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4"))) + assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42"))) + assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41"))) + assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1"))) + assert.True(t, al.Allow(netip.MustParseAddr("::1"))) + assert.False(t, al.Allow(netip.MustParseAddr("::2"))) } func TestLocalAllowList_AllowName(t *testing.T) { - assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0")) + assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0")) rules := []AllowListNameRule{ {Name: regexp.MustCompile("^docker.*$"), Allow: false}, @@ -129,9 +129,9 @@ func TestLocalAllowList_AllowName(t *testing.T) { } al := &LocalAllowList{nameRules: rules} - assert.Equal(t, false, al.AllowName("docker0")) - assert.Equal(t, false, al.AllowName("tun0")) - assert.Equal(t, true, al.AllowName("eth0")) + assert.False(t, al.AllowName("docker0")) + assert.False(t, al.AllowName("tun0")) + assert.True(t, al.AllowName("eth0")) rules = []AllowListNameRule{ {Name: regexp.MustCompile("^eth.*$"), Allow: true}, @@ -139,7 +139,7 @@ func TestLocalAllowList_AllowName(t *testing.T) { } al = &LocalAllowList{nameRules: rules} - assert.Equal(t, false, al.AllowName("docker0")) - assert.Equal(t, true, al.AllowName("eth0")) - assert.Equal(t, true, al.AllowName("ens5")) + assert.False(t, al.AllowName("docker0")) + assert.True(t, al.AllowName("eth0")) + assert.True(t, al.AllowName("ens5")) } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index f03b2ba..2f9255f 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -82,32 +82,32 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe } p, err := NewCAPoolFromPEM([]byte(noNewLines)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) // expired cert, no valid certs ppp, err := NewCAPoolFromPEM([]byte(expired)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") + assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name()) // expired cert, with valid certs pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) assert.Equal(t, ErrExpired, err) assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) - assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") - assert.Equal(t, len(pppp.CAs), 3) + assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name()) + assert.Len(t, pppp.CAs, 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) - assert.Equal(t, len(ppppp.CAs), 1) + assert.Len(t, ppppp.CAs, 1) } func TestCertificateV1_Verify(t *testing.T) { @@ -118,7 +118,7 @@ func TestCertificateV1_Verify(t *testing.T) { assert.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.Nil(t, err) + assert.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) @@ -126,7 +126,7 @@ func TestCertificateV1_Verify(t *testing.T) { caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") @@ -138,7 +138,7 @@ func TestCertificateV1_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -150,9 +150,9 @@ func TestCertificateV1_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV1_VerifyP256(t *testing.T) { @@ -163,7 +163,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { assert.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.Nil(t, err) + assert.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) @@ -171,7 +171,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") @@ -183,7 +183,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -196,7 +196,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV1_Verify_IPs(t *testing.T) { @@ -205,7 +205,7 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -245,25 +245,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV1_Verify_Subnets(t *testing.T) { @@ -272,7 +272,7 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -311,27 +311,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV2_Verify(t *testing.T) { @@ -342,7 +342,7 @@ func TestCertificateV2_Verify(t *testing.T) { assert.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.Nil(t, err) + assert.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) @@ -350,7 +350,7 @@ func TestCertificateV2_Verify(t *testing.T) { caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") @@ -362,7 +362,7 @@ func TestCertificateV2_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -374,9 +374,9 @@ func TestCertificateV2_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV2_VerifyP256(t *testing.T) { @@ -387,7 +387,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { assert.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.Nil(t, err) + assert.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) @@ -395,7 +395,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") @@ -407,7 +407,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -420,7 +420,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV2_Verify_IPs(t *testing.T) { @@ -429,7 +429,7 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -469,25 +469,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV2_Verify_Subnets(t *testing.T) { @@ -496,7 +496,7 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -535,25 +535,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go index 8c3fe93..ea98b08 100644 --- a/cert/cert_v1_test.go +++ b/cert/cert_v1_test.go @@ -39,14 +39,14 @@ func TestCertificateV1_Marshal(t *testing.T) { } b, err := nc.Marshal() - assert.Nil(t, err) + assert.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV1(b, nil) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, nc.Version(), Version1) - assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, Version1, nc.Version()) + assert.Equal(t, Curve_CURVE25519, nc.Curve()) assert.Equal(t, nc.Signature(), nc2.Signature()) assert.Equal(t, nc.Name(), nc2.Name()) assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) @@ -99,8 +99,8 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { } b, err := nc.MarshalJSON() - assert.Nil(t, err) - assert.Equal( + assert.NoError(t, err) + assert.JSONEq( t, "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", string(b), @@ -110,12 +110,12 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { func TestCertificateV1_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) + assert.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) + assert.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.NotNil(t, err) + assert.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) @@ -123,22 +123,22 @@ func TestCertificateV1_VerifyPrivateKey(t *testing.T) { assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.NotNil(t, err) + assert.Error(t, err) } func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.Nil(t, err) + assert.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) + assert.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.NotNil(t, err) + assert.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) @@ -146,11 +146,11 @@ func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.NotNil(t, err) + assert.Error(t, err) } // Ensure that upgrading the protobuf library does not change how certificates @@ -182,11 +182,11 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) { } b, err := nc.Marshal() - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) b, err = proto.Marshal(nc.getRawDetails()) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) } diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go index 3afbcab..6d55750 100644 --- a/cert/cert_v2_test.go +++ b/cert/cert_v2_test.go @@ -45,14 +45,14 @@ func TestCertificateV2_Marshal(t *testing.T) { nc.rawDetails = db b, err := nc.Marshal() - require.Nil(t, err) + require.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, nc.Version(), Version2) - assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, Version2, nc.Version()) + assert.Equal(t, Curve_CURVE25519, nc.Curve()) assert.Equal(t, nc.Signature(), nc2.Signature()) assert.Equal(t, nc.Name(), nc2.Name()) assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) @@ -121,8 +121,8 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { nc.rawDetails = rd b, err = nc.MarshalJSON() - assert.Nil(t, err) - assert.Equal( + assert.NoError(t, err) + assert.JSONEq( t, "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", string(b), @@ -132,13 +132,13 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { func TestCertificateV2_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) + assert.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) assert.ErrorIs(t, err, ErrInvalidPrivateKey) _, caKey2, err := ed25519.GenerateKey(rand.Reader) - require.Nil(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) @@ -148,7 +148,7 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) @@ -168,7 +168,7 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) + assert.NoError(t, err) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) assert.ErrorIs(t, err, ErrInvalidPrivateKey) @@ -193,12 +193,12 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.Nil(t, err) + assert.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) + assert.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.NotNil(t, err) + assert.Error(t, err) c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) @@ -206,11 +206,11 @@ func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.NotNil(t, err) + assert.Error(t, err) } func TestCertificateV2_Copy(t *testing.T) { diff --git a/cert/crypto_test.go b/cert/crypto_test.go index c9aba3e..c43eed7 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -61,7 +61,7 @@ qrlJ69wer3ZUHFXA // Success test case curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, Curve_CURVE25519, curve) assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) @@ -89,7 +89,7 @@ qrlJ69wer3ZUHFXA curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) assert.EqualError(t, err, "invalid passphrase or corrupt private key") assert.Nil(t, k) - assert.Equal(t, rest, []byte{}) + assert.Equal(t, []byte{}, rest) } func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { @@ -99,14 +99,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") kdfParams := NewArgon2Parameters(64*1024, 4, 3) key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) - assert.Nil(t, err) + assert.NoError(t, err) // Verify the "key" can be decrypted successfully curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) assert.Len(t, k, 64) assert.Equal(t, Curve_CURVE25519, curve) - assert.Equal(t, rest, []byte{}) - assert.Nil(t, err) + assert.Equal(t, []byte{}, rest) + assert.NoError(t, err) // EncryptAndMarshalEd25519PrivateKey does not create any errors itself } diff --git a/cert/pem_test.go b/cert/pem_test.go index a0c6e74..9ad8a69 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -35,7 +35,7 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB cert, rest, err := UnmarshalCertificateFromPEM(certBundle) assert.NotNil(t, cert) assert.Equal(t, rest, append(badBanner, invalidPem...)) - assert.Nil(t, err) + assert.NoError(t, err) // Fail due to invalid banner. cert, rest, err = UnmarshalCertificateFromPEM(rest) @@ -84,14 +84,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) + assert.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) + assert.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) @@ -146,14 +146,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) + assert.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) + assert.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) @@ -200,9 +200,9 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) - assert.Equal(t, 32, len(k)) + assert.Len(t, k, 32) assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key @@ -259,15 +259,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) - assert.Equal(t, 32, len(k)) - assert.Nil(t, err) + assert.Len(t, k, 32) + assert.NoError(t, err) assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) // Success test case k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) - assert.Equal(t, 65, len(k)) - assert.Nil(t, err) + assert.Len(t, k, 65) + assert.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) diff --git a/cert/sign_test.go b/cert/sign_test.go index 2b8dbe8..30d8480 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -37,14 +37,14 @@ func TestCertificateV1_Sign(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.Nil(t, err) + assert.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, uc) } @@ -78,13 +78,13 @@ func TestCertificateV1_SignP256(t *testing.T) { rawPriv := priv.D.FillBytes(make([]byte, 32)) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.Nil(t, err) + assert.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, uc) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 9da0ad4..71b69be 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -112,8 +112,8 @@ func Test_ca(t *testing.T) { // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) - assert.Nil(t, os.Remove(keyF.Name())) + assert.NoError(t, err) + assert.NoError(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() @@ -125,15 +125,15 @@ func Test_ca(t *testing.T) { // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.Nil(t, err) - assert.Nil(t, os.Remove(crtF.Name())) - assert.Nil(t, os.Remove(keyF.Name())) + assert.NoError(t, err) + assert.NoError(t, os.Remove(crtF.Name())) + assert.NoError(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, nopw)) + assert.NoError(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -141,20 +141,20 @@ func Test_ca(t *testing.T) { rb, _ := os.ReadFile(keyF.Name()) lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, c) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) - assert.Len(t, lCrt.Networks(), 0) + assert.Empty(t, lCrt.Networks()) assert.True(t, lCrt.IsCA()) assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) - assert.Len(t, lCrt.UnsafeNetworks(), 0) + assert.Empty(t, lCrt.UnsafeNetworks()) assert.Len(t, lCrt.PublicKey(), 32) assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) assert.Equal(t, "", lCrt.Issuer()) @@ -166,7 +166,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, testpw)) + assert.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -174,7 +174,7 @@ func Test_ca(t *testing.T) { rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) - assert.Nil(t, err) + assert.NoError(t, err) // we won't know salt in advance, so just check start of string assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) @@ -184,8 +184,8 @@ func Test_ca(t *testing.T) { var curve cert.Curve curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Nil(t, err) - assert.Len(t, b, 0) + assert.NoError(t, err) + assert.Empty(t, b) assert.Len(t, lKey, 64) // test when reading passsword results in an error @@ -214,7 +214,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, nopw)) + assert.NoError(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index fcfd77b..3427254 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -53,7 +53,7 @@ func Test_keygen(t *testing.T) { // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(keyF.Name()) // failed pub write @@ -66,14 +66,14 @@ func Test_keygen(t *testing.T) { // create temp pub file pubF, err := os.CreateTemp("", "test.pub") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(pubF.Name()) // test proper keygen ob.Reset() eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, keygen(args, ob, eb)) + assert.NoError(t, keygen(args, ob, eb)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -81,14 +81,14 @@ func Test_keygen(t *testing.T) { rb, _ := os.ReadFile(keyF.Name()) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Len(t, lPub, 32) } diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 86795e4..77e98e6 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -58,7 +58,7 @@ func Test_printCert(t *testing.T) { ob.Reset() eb.Reset() tf, err := os.CreateTemp("", "print-cert") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(tf.Name()) tf.WriteString("-----BEGIN NOPE-----") @@ -84,7 +84,7 @@ func Test_printCert(t *testing.T) { fp, _ := c.Fingerprint() pk := hex.EncodeToString(c.PublicKey()) sig := hex.EncodeToString(c.Signature()) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal( t, //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", @@ -169,7 +169,7 @@ func Test_printCert(t *testing.T) { fp, _ = c.Fingerprint() pk = hex.EncodeToString(c.PublicKey()) sig = hex.EncodeToString(c.Signature()) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal( t, `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 466cb8c..4b242a4 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -109,7 +109,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caKeyF, err := os.CreateTemp("", "sign-cert.key") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caKeyF.Name()) args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} @@ -133,7 +133,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caCrtF, err := os.CreateTemp("", "sign-cert.crt") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caCrtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} @@ -156,7 +156,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPubF, err := os.CreateTemp("", "in.pub") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(inPubF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} @@ -210,7 +210,7 @@ func Test_signCert(t *testing.T) { // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caKeyF2.Name()) caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) @@ -231,7 +231,7 @@ func Test_signCert(t *testing.T) { // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) + assert.NoError(t, err) os.Remove(keyF.Name()) // failed cert write @@ -245,14 +245,14 @@ func Test_signCert(t *testing.T) { // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.Nil(t, err) + assert.NoError(t, err) os.Remove(crtF.Name()) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + assert.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -260,14 +260,14 @@ func Test_signCert(t *testing.T) { rb, _ := os.ReadFile(keyF.Name()) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) @@ -295,15 +295,15 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + assert.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert file and check pub key matches in-pub rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root @@ -320,7 +320,7 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + assert.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) @@ -335,7 +335,7 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + assert.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) @@ -355,11 +355,11 @@ func Test_signCert(t *testing.T) { eb.Reset() caKeyF, err = os.CreateTemp("", "sign-cert.key") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caKeyF.Name()) caCrtF, err = os.CreateTemp("", "sign-cert.crt") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caCrtF.Name()) // generate the encrypted key @@ -374,7 +374,7 @@ func Test_signCert(t *testing.T) { // test with the proper password args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, testpw)) + assert.NoError(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index d94bd1f..c2a9f55 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -3,7 +3,6 @@ package main import ( "bytes" "crypto/rand" - "errors" "os" "testing" "time" @@ -57,7 +56,7 @@ func Test_verify(t *testing.T) { ob.Reset() eb.Reset() caFile, err := os.CreateTemp("", "verify-ca") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caFile.Name()) caFile.WriteString("-----BEGIN NOPE-----") @@ -84,7 +83,7 @@ func Test_verify(t *testing.T) { ob.Reset() eb.Reset() certFile, err := os.CreateTemp("", "verify-cert") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(certFile.Name()) certFile.WriteString("-----BEGIN NOPE-----") @@ -108,7 +107,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.True(t, errors.Is(err, cert.ErrSignatureMismatch)) + assert.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) @@ -120,5 +119,5 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.Nil(t, err) + assert.NoError(t, err) } diff --git a/config/config_test.go b/config/config_test.go index c3a1a73..39301f9 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -26,11 +26,11 @@ func TestConfig_Load(t *testing.T) { os.RemoveAll(dir) os.Mkdir(dir, 0755) - assert.Nil(t, err) + assert.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) - assert.Nil(t, c.Load(dir)) + assert.NoError(t, c.Load(dir)) expected := map[interface{}]interface{}{ "outer": map[interface{}]interface{}{ "inner": "override", @@ -67,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) { l := test.NewLogger() c := NewC(l) c.Settings["bool"] = true - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "true" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = false - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "false" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "Y" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "yEs" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "N" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "nO" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) } func TestConfig_HasChanged(t *testing.T) { @@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) dir, err := os.MkdirTemp("", "config-test") - assert.Nil(t, err) + assert.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) - assert.Nil(t, c.Load(dir)) + assert.NoError(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer")) diff --git a/firewall_test.go b/firewall_test.go index 8d32369..92914af 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -68,53 +68,53 @@ func TestFirewall_AddRule(t *testing.T) { ti, err := netip.ParsePrefix("1.2.3.4/32") assert.NoError(t, err) - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) + assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") assert.NoError(t, err) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions @@ -155,7 +155,7 @@ func TestFirewall_Drop(t *testing.T) { h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -174,28 +174,28 @@ func TestFirewall_Drop(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -350,11 +350,11 @@ func TestFirewall_Drop2(t *testing.T) { h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + assert.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) @@ -428,8 +428,8 @@ func TestFirewall_Drop3(t *testing.T) { h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -443,7 +443,7 @@ func TestFirewall_Drop3(t *testing.T) { // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } @@ -480,7 +480,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -493,7 +493,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -502,7 +502,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -605,22 +605,22 @@ func Test_parsePort(t *testing.T) { s, e, err := parsePort(" 1 - 2 ") assert.Equal(t, int32(1), s) assert.Equal(t, int32(2), e) - assert.Nil(t, err) + assert.NoError(t, err) s, e, err = parsePort("0-1") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.Nil(t, err) + assert.NoError(t, err) s, e, err = parsePort("9919") assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), e) - assert.Nil(t, err) + assert.NoError(t, err) s, e, err = parsePort("any") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.Nil(t, err) + assert.NoError(t, err) } func TestNewFirewallFromConfig(t *testing.T) { @@ -688,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr @@ -717,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error @@ -782,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err := convertRule(l, c, "test", 1) assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, "group1", r.Group) // Ensure group array of > 1 is errord @@ -802,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) { } r, err = convertRule(l, c, "test", 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, "group1", r.Group) } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 7edc55b..4b898af 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -44,7 +44,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { i.remotes = NewRemoteList([]netip.Addr{}, nil) // Adding something to pending should not affect the main hostmap - assert.Len(t, mainHM.Hosts, 0) + assert.Empty(t, mainHM.Hosts) // Confirm they are in the pending index list assert.Contains(t, blah.vpnIps, ip) diff --git a/header/header_test.go b/header/header_test.go index 765a006..1836a75 100644 --- a/header/header_test.go +++ b/header/header_test.go @@ -111,7 +111,7 @@ func TestHeader_String(t *testing.T) { func TestHeader_MarshalJSON(t *testing.T) { b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal( t, "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", diff --git a/lighthouse_test.go b/lighthouse_test.go index d5947aa..9e9ad53 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -42,7 +42,7 @@ func Test_lhStaticMapping(t *testing.T) { c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.Nil(t, err) + assert.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) diff --git a/outside_test.go b/outside_test.go index f197594..944bf16 100644 --- a/outside_test.go +++ b/outside_test.go @@ -63,7 +63,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 3, 0, 4}...) err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) @@ -85,7 +85,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 5, 0, 6}...) err = newPacket(b, false, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(2), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) @@ -134,7 +134,7 @@ func Test_newPacket_v6(t *testing.T) { } err = newPacket(buffer.Bytes(), true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -146,7 +146,7 @@ func Test_newPacket_v6(t *testing.T) { b := buffer.Bytes() b[6] = byte(layers.IPProtocolESP) err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -158,7 +158,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = byte(layers.IPProtocolNoNextHeader) err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -197,7 +197,7 @@ func Test_newPacket_v6(t *testing.T) { // incoming err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -207,7 +207,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -224,7 +224,7 @@ func Test_newPacket_v6(t *testing.T) { // incoming err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -234,7 +234,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -279,7 +279,7 @@ func Test_newPacket_v6(t *testing.T) { b = append(b, udpHeader...) err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) diff --git a/overlay/route_test.go b/overlay/route_test.go index c60e4c2..4fa30af 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -18,8 +18,8 @@ func Test_parseRoutes(t *testing.T) { // test no routes config routes, err := parseRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) - assert.Len(t, routes, 0) + assert.NoError(t, err) + assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} @@ -30,8 +30,8 @@ func Test_parseRoutes(t *testing.T) { // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) - assert.Len(t, routes, 0) + assert.NoError(t, err) + assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} @@ -93,7 +93,7 @@ func Test_parseRoutes(t *testing.T) { map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, routes, 2) tested := 0 @@ -123,8 +123,8 @@ func Test_parseUnsafeRoutes(t *testing.T) { // test no routes config routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) - assert.Len(t, routes, 0) + assert.NoError(t, err) + assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} @@ -135,8 +135,8 @@ func Test_parseUnsafeRoutes(t *testing.T) { // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) - assert.Len(t, routes, 0) + assert.NoError(t, err) + assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} @@ -188,13 +188,13 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.Nil(t, err) + assert.NoError(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.Nil(t, err) + assert.NoError(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} @@ -228,7 +228,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, routes, 4) tested := 0 diff --git a/punchy_test.go b/punchy_test.go index bedd2b2..7918449 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -15,31 +15,31 @@ func TestNewPunchyFromConfig(t *testing.T) { // Test defaults p := NewPunchyFromConfig(l, c) - assert.Equal(t, false, p.GetPunch()) - assert.Equal(t, false, p.GetRespond()) + assert.False(t, p.GetPunch()) + assert.False(t, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) assert.Equal(t, 5*time.Second, p.GetRespondDelay()) // punchy deprecation c.Settings["punchy"] = true p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetPunch()) + assert.True(t, p.GetPunch()) // punchy.punch c.Settings["punchy"] = map[interface{}]interface{}{"punch": true} p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetPunch()) + assert.True(t, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) // punchy.respond c.Settings["punchy"] = map[interface{}]interface{}{"respond": true} c.Settings["punch_back"] = false p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) // punchy.delay c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"} @@ -63,7 +63,7 @@ punchy: `)) p := NewPunchyFromConfig(l, c) assert.Equal(t, delay, p.GetDelay()) - assert.Equal(t, false, p.GetRespond()) + assert.False(t, p.GetRespond()) newDelay, _ := time.ParseDuration("10m") assert.NoError(t, c.ReloadConfigString(` @@ -73,5 +73,5 @@ punchy: `)) p.reload(c, false) assert.Equal(t, newDelay, p.GetDelay()) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) } From 088af8edb264ec1a25d947e192da7938c48d18d4 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Mon, 10 Mar 2025 17:38:14 -0500 Subject: [PATCH 04/44] Enable running testifylint in CI (#1350) --- .github/workflows/test.yml | 10 +++ .golangci.yaml | 9 ++ allow_list_test.go | 13 +-- calculated_remote_test.go | 16 ++-- cert/ca_pool_test.go | 151 +++++++++++++++++---------------- cert/cert_v1_test.go | 34 ++++---- cert/cert_v2_test.go | 50 +++++------ cert/crypto_test.go | 15 ++-- cert/pem_test.go | 45 +++++----- cert/sign_test.go | 15 ++-- cmd/nebula-cert/ca_test.go | 37 ++++---- cmd/nebula-cert/keygen_test.go | 15 ++-- cmd/nebula-cert/main_test.go | 3 +- cmd/nebula-cert/print_test.go | 11 +-- cmd/nebula-cert/sign_test.go | 63 +++++++------- cmd/nebula-cert/verify_test.go | 17 ++-- config/config_test.go | 10 +-- connection_manager_test.go | 7 +- e2e/handshakes_test.go | 21 ++--- firewall_test.go | 150 ++++++++++++++++---------------- header/header_test.go | 3 +- lighthouse_test.go | 31 ++++--- outside_test.go | 57 +++++++------ overlay/route_test.go | 79 ++++++++--------- punchy_test.go | 5 +- 25 files changed, 451 insertions(+), 416 deletions(-) create mode 100644 .golangci.yaml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4f3f2ed..b8a4f03 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,6 +31,11 @@ jobs: - name: Vet run: make vet + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.64 + - name: Test run: make test @@ -109,6 +114,11 @@ jobs: - name: Vet run: make vet + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.64 + - name: Test run: make test diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..f792069 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,9 @@ +# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json +linters: + # Disable all linters. + # Default: false + disable-all: true + # Enable specific linter + # https://golangci-lint.run/usage/linters/#enabled-by-default + enable: + - testifylint diff --git a/allow_list_test.go b/allow_list_test.go index 6d5e76b..d7d2c9a 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -9,6 +9,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewAllowListFromConfig(t *testing.T) { @@ -18,21 +19,21 @@ func TestNewAllowListFromConfig(t *testing.T) { "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") + require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": "abc", } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") + require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": true, "10.0.0.0/8": false, } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") + require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") c.Settings["allowlist"] = map[interface{}]interface{}{ "0.0.0.0/0": true, @@ -42,7 +43,7 @@ func TestNewAllowListFromConfig(t *testing.T) { "fd00:fd00::/16": false, } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") + require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") c.Settings["allowlist"] = map[interface{}]interface{}{ "0.0.0.0/0": true, @@ -75,7 +76,7 @@ func TestNewAllowListFromConfig(t *testing.T) { }, } lr, err := NewLocalAllowListFromConfig(c, "allowlist") - assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") + require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") c.Settings["allowlist"] = map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{ @@ -84,7 +85,7 @@ func TestNewAllowListFromConfig(t *testing.T) { }, } lr, err = NewLocalAllowListFromConfig(c, "allowlist") - assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") + require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") c.Settings["allowlist"] = map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{ diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 066213e..6df893c 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -15,10 +15,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err := netip.ParseAddr("10.0.10.182") - assert.NoError(t, err) + require.NoError(t, err) expected, err := netip.ParseAddr("192.168.1.182") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) @@ -28,10 +28,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) @@ -41,10 +41,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) @@ -54,10 +54,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index 2f9255f..b0fdd5f 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCAPoolFromBytes(t *testing.T) { @@ -82,12 +83,12 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe } p, err := NewCAPoolFromPEM([]byte(noNewLines)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) @@ -105,7 +106,7 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe assert.Len(t, pppp.CAs, 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) assert.Len(t, ppppp.CAs, 1) } @@ -115,21 +116,21 @@ func TestCertificateV1_Verify(t *testing.T) { c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) @@ -138,11 +139,11 @@ func TestCertificateV1_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -150,9 +151,9 @@ func TestCertificateV1_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_VerifyP256(t *testing.T) { @@ -160,21 +161,21 @@ func TestCertificateV1_VerifyP256(t *testing.T) { c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) @@ -183,11 +184,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -196,7 +197,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_Verify_IPs(t *testing.T) { @@ -205,11 +206,11 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -245,25 +246,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_Verify_Subnets(t *testing.T) { @@ -272,11 +273,11 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -311,27 +312,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify(t *testing.T) { @@ -339,21 +340,21 @@ func TestCertificateV2_Verify(t *testing.T) { c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) @@ -362,11 +363,11 @@ func TestCertificateV2_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -374,9 +375,9 @@ func TestCertificateV2_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_VerifyP256(t *testing.T) { @@ -384,21 +385,21 @@ func TestCertificateV2_VerifyP256(t *testing.T) { c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) @@ -407,11 +408,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -420,7 +421,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify_IPs(t *testing.T) { @@ -429,11 +430,11 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -469,25 +470,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify_Subnets(t *testing.T) { @@ -496,11 +497,11 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -535,25 +536,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go index ea98b08..c687172 100644 --- a/cert/cert_v1_test.go +++ b/cert/cert_v1_test.go @@ -39,11 +39,11 @@ func TestCertificateV1_Marshal(t *testing.T) { } b, err := nc.Marshal() - assert.NoError(t, err) + require.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Version1, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) @@ -99,7 +99,7 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { } b, err := nc.MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.JSONEq( t, "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", @@ -110,47 +110,47 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { func TestCertificateV1_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.Error(t, err) + require.Error(t, err) } func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.Error(t, err) + require.Error(t, err) } // Ensure that upgrading the protobuf library does not change how certificates @@ -186,7 +186,7 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) { assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) b, err = proto.Marshal(nc.getRawDetails()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) } @@ -201,7 +201,7 @@ func TestUnmarshalCertificateV1(t *testing.T) { // Test that we don't panic with an invalid certificate (#332) data := []byte("\x98\x00\x00") _, err := unmarshalCertificateV1(data, nil) - assert.EqualError(t, err, "encoded Details was nil") + require.EqualError(t, err, "encoded Details was nil") } func appendByteSlices(b ...[]byte) []byte { diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go index 6d55750..c84f8c9 100644 --- a/cert/cert_v2_test.go +++ b/cert/cert_v2_test.go @@ -49,7 +49,7 @@ func TestCertificateV2_Marshal(t *testing.T) { //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Version2, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) @@ -114,14 +114,14 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { } b, err := nc.MarshalJSON() - assert.ErrorIs(t, err, ErrMissingDetails) + require.ErrorIs(t, err, ErrMissingDetails) rd, err := nc.details.Marshal() - assert.NoError(t, err) + require.NoError(t, err) nc.rawDetails = rd b, err = nc.MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.JSONEq( t, "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", @@ -132,85 +132,85 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { func TestCertificateV2_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) _, caKey2, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) + require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) ac, ok := c.(*certificateV2) require.True(t, ok) ac.curve = Curve(99) err = c.VerifyPrivateKey(Curve(99), priv2) - assert.EqualError(t, err, "invalid curve: 99") + require.EqualError(t, err, "invalid curve: 99") ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) err = c.VerifyPrivateKey(Curve_P256, priv[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) err = c.VerifyPrivateKey(Curve_P256, priv) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) aCa, ok := ca2.(*certificateV2) require.True(t, ok) aCa.curve = Curve(99) err = aCa.VerifyPrivateKey(Curve(99), priv2) - assert.EqualError(t, err, "invalid curve: 99") + require.EqualError(t, err, "invalid curve: 99") } func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.Error(t, err) + require.Error(t, err) } func TestCertificateV2_Copy(t *testing.T) { @@ -223,7 +223,7 @@ func TestCertificateV2_Copy(t *testing.T) { func TestUnmarshalCertificateV2(t *testing.T) { data := []byte("\x98\x00\x00") _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) - assert.EqualError(t, err, "bad wire format") + require.EqualError(t, err, "bad wire format") } func TestCertificateV2_marshalForSigningStability(t *testing.T) { diff --git a/cert/crypto_test.go b/cert/crypto_test.go index c43eed7..ee671c0 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/argon2" ) @@ -61,33 +62,33 @@ qrlJ69wer3ZUHFXA // Success test case curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Curve_CURVE25519, curve) assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) // Fail due to invalid banner curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) // Fail due to invalid passphrase curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) - assert.EqualError(t, err, "invalid passphrase or corrupt private key") + require.EqualError(t, err, "invalid passphrase or corrupt private key") assert.Nil(t, k) assert.Equal(t, []byte{}, rest) } @@ -99,14 +100,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") kdfParams := NewArgon2Parameters(64*1024, 4, 3) key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) - assert.NoError(t, err) + require.NoError(t, err) // Verify the "key" can be decrypted successfully curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) assert.Len(t, k, 64) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, []byte{}, rest) - assert.NoError(t, err) + require.NoError(t, err) // EncryptAndMarshalEd25519PrivateKey does not create any errors itself } diff --git a/cert/pem_test.go b/cert/pem_test.go index 9ad8a69..6e49249 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshalCertificateFromPEM(t *testing.T) { @@ -35,20 +36,20 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB cert, rest, err := UnmarshalCertificateFromPEM(certBundle) assert.NotNil(t, cert) assert.Equal(t, rest, append(badBanner, invalidPem...)) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to invalid banner. cert, rest, err = UnmarshalCertificateFromPEM(rest) assert.Nil(t, cert) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper certificate banner") + require.EqualError(t, err, "bytes did not contain a proper certificate banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. cert, rest, err = UnmarshalCertificateFromPEM(rest) assert.Nil(t, cert) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { @@ -84,33 +85,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") + require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") + require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { @@ -146,33 +147,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper private key banner") + require.EqualError(t, err, "bytes did not contain a proper private key banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalPublicKeyFromPEM(t *testing.T) { @@ -202,7 +203,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key @@ -210,13 +211,13 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) - assert.EqualError(t, err, "bytes did not contain a proper public key banner") + require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because @@ -225,7 +226,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalX25519PublicKey(t *testing.T) { @@ -260,14 +261,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) // Success test case k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Len(t, k, 65) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) @@ -275,12 +276,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper public key banner") + require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because @@ -288,5 +289,5 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } diff --git a/cert/sign_test.go b/cert/sign_test.go index 30d8480..e6f43cd 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCertificateV1_Sign(t *testing.T) { @@ -37,14 +38,14 @@ func TestCertificateV1_Sign(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.NoError(t, err) + require.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, uc) } @@ -73,18 +74,18 @@ func TestCertificateV1_SignP256(t *testing.T) { } priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - assert.NoError(t, err) + require.NoError(t, err) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) rawPriv := priv.D.FillBytes(make([]byte, 32)) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.NoError(t, err) + require.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, uc) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 71b69be..189fc02 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -14,6 +14,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_caSummary(t *testing.T) { @@ -106,34 +107,34 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) - assert.NoError(t, os.Remove(keyF.Name())) + require.NoError(t, err) + require.NoError(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.NoError(t, err) - assert.NoError(t, os.Remove(crtF.Name())) - assert.NoError(t, os.Remove(keyF.Name())) + require.NoError(t, err) + require.NoError(t, os.Remove(crtF.Name())) + require.NoError(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, nopw)) + require.NoError(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -142,13 +143,13 @@ func Test_ca(t *testing.T) { lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, c) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Empty(t, lCrt.Networks()) @@ -166,7 +167,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, testpw)) + require.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -174,7 +175,7 @@ func Test_ca(t *testing.T) { rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) - assert.NoError(t, err) + require.NoError(t, err) // we won't know salt in advance, so just check start of string assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) @@ -184,7 +185,7 @@ func Test_ca(t *testing.T) { var curve cert.Curve curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Len(t, lKey, 64) @@ -194,7 +195,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Error(t, ca(args, ob, eb, errpw)) + require.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -204,7 +205,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") + require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, "", eb.String()) @@ -214,13 +215,13 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, nopw)) + require.NoError(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) + require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -229,7 +230,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) + require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) os.Remove(keyF.Name()) diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 3427254..7eed5d2 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_keygenSummary(t *testing.T) { @@ -47,33 +48,33 @@ func Test_keygen(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(keyF.Name()) // failed pub write ob.Reset() eb.Reset() args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} - assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) + require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp pub file pubF, err := os.CreateTemp("", "test.pub") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(pubF.Name()) // test proper keygen ob.Reset() eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, keygen(args, ob, eb)) + require.NoError(t, keygen(args, ob, eb)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -82,13 +83,13 @@ func Test_keygen(t *testing.T) { lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lPub, 32) } diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go index f332895..2e92e7e 100644 --- a/cmd/nebula-cert/main_test.go +++ b/cmd/nebula-cert/main_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_help(t *testing.T) { @@ -79,7 +80,7 @@ func assertHelpError(t *testing.T, err error, msg string) { t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) } - assert.EqualError(t, err, msg) + require.EqualError(t, err, msg) } func optionalPkcs11String(msg string) string { diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 77e98e6..061e472 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_printSummary(t *testing.T) { @@ -52,20 +53,20 @@ func Test_printCert(t *testing.T) { err = printCert([]string{"-path", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) // invalid cert at path ob.Reset() eb.Reset() tf, err := os.CreateTemp("", "print-cert") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(tf.Name()) tf.WriteString("-----BEGIN NOPE-----") err = printCert([]string{"-path", tf.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") // test multiple certs ob.Reset() @@ -84,7 +85,7 @@ func Test_printCert(t *testing.T) { fp, _ := c.Fingerprint() pk := hex.EncodeToString(c.PublicKey()) sig := hex.EncodeToString(c.Signature()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", @@ -169,7 +170,7 @@ func Test_printCert(t *testing.T) { fp, _ = c.Fingerprint() pk = hex.EncodeToString(c.PublicKey()) sig = hex.EncodeToString(c.Signature()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 4b242a4..b2bba76 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) @@ -103,17 +104,17 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key ob.Reset() eb.Reset() caKeyF, err := os.CreateTemp("", "sign-cert.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF.Name()) args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -125,7 +126,7 @@ func Test_signCert(t *testing.T) { // failed to read cert args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -133,11 +134,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caCrtF, err := os.CreateTemp("", "sign-cert.crt") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caCrtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -148,7 +149,7 @@ func Test_signCert(t *testing.T) { // failed to read pub args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -156,11 +157,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPubF, err := os.CreateTemp("", "in.pub") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(inPubF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -210,14 +211,14 @@ func Test_signCert(t *testing.T) { // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF2.Name()) caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -225,34 +226,34 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) + require.NoError(t, err) os.Remove(keyF.Name()) // failed cert write ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Remove(keyF.Name()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.NoError(t, err) + require.NoError(t, err) os.Remove(crtF.Name()) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -261,13 +262,13 @@ func Test_signCert(t *testing.T) { lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) @@ -295,7 +296,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -303,7 +304,7 @@ func Test_signCert(t *testing.T) { rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root @@ -312,7 +313,7 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -320,14 +321,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -335,14 +336,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -355,11 +356,11 @@ func Test_signCert(t *testing.T) { eb.Reset() caKeyF, err = os.CreateTemp("", "sign-cert.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF.Name()) caCrtF, err = os.CreateTemp("", "sign-cert.crt") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caCrtF.Name()) // generate the encrypted key @@ -374,7 +375,7 @@ func Test_signCert(t *testing.T) { // test with the proper password args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, testpw)) + require.NoError(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -384,7 +385,7 @@ func Test_signCert(t *testing.T) { testpw.password = []byte("invalid password") args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, testpw)) + require.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -393,7 +394,7 @@ func Test_signCert(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, nopw)) + require.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -403,7 +404,7 @@ func Test_signCert(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, errpw)) + require.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index c2a9f55..acc9cca 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -9,6 +9,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) @@ -50,20 +51,20 @@ func Test_verify(t *testing.T) { err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) // invalid ca at path ob.Reset() eb.Reset() caFile, err := os.CreateTemp("", "verify-ca") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caFile.Name()) caFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) @@ -77,20 +78,20 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() eb.Reset() certFile, err := os.CreateTemp("", "verify-cert") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(certFile.Name()) certFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) @@ -107,7 +108,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.ErrorIs(t, err, cert.ErrSignatureMismatch) + require.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) @@ -119,5 +120,5 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/config/config_test.go b/config/config_test.go index 39301f9..468c642 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -19,18 +19,18 @@ func TestConfig_Load(t *testing.T) { // invalid yaml c := NewC(l) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) - assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") + require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") // simple multi config merge c = NewC(l) os.RemoveAll(dir) os.Mkdir(dir, 0755) - assert.NoError(t, err) + require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) - assert.NoError(t, c.Load(dir)) + require.NoError(t, c.Load(dir)) expected := map[interface{}]interface{}{ "outer": map[interface{}]interface{}{ "inner": "override", @@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) dir, err := os.MkdirTemp("", "config-test") - assert.NoError(t, err) + require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) - assert.NoError(t, c.Load(dir)) + require.NoError(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer")) diff --git a/connection_manager_test.go b/connection_manager_test.go index 8e2ef15..2c9baa1 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -14,6 +14,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func newTestLighthouse() *LightHouse { @@ -223,9 +224,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { } caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) - assert.NoError(t, err) + require.NoError(t, err) ncp := cert.NewCAPool() - assert.NoError(t, ncp.AddCA(caCert)) + require.NoError(t, ncp.AddCA(caCert)) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) tbs = &cert.TBSCertificate{ @@ -237,7 +238,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { PublicKey: pubCrt, } peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) - assert.NoError(t, err) + require.NoError(t, err) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 2e7e6e4..06f2a21 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -19,6 +19,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -771,7 +772,7 @@ func TestRehandshakingRelays(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { @@ -875,7 +876,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { @@ -970,7 +971,7 @@ func TestRehandshaking(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(myConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) for { @@ -987,9 +988,9 @@ func TestRehandshaking(t *testing.T) { r.Log("Got the new cert") // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(theirConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) var theirNewConfig m - assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) + require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall["inbound"] = []m{{ "proto": "any", @@ -997,7 +998,7 @@ func TestRehandshaking(t *testing.T) { "group": "new group", }} rc, err = yaml.Marshal(theirNewConfig) - assert.NoError(t, err) + require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") @@ -1067,7 +1068,7 @@ func TestRehandshakingLoser(t *testing.T) { "key": string(theirNextPrivKey), } rc, err := yaml.Marshal(theirConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) for { @@ -1083,9 +1084,9 @@ func TestRehandshakingLoser(t *testing.T) { // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(myConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) var myNewConfig m - assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) + require.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall["inbound"] = []m{{ "proto": "any", @@ -1093,7 +1094,7 @@ func TestRehandshakingLoser(t *testing.T) { "group": "their new group", }} rc, err = yaml.Marshal(myNewConfig) - assert.NoError(t, err) + require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") diff --git a/firewall_test.go b/firewall_test.go index 92914af..8c2eeb0 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -66,61 +66,61 @@ func TestFirewall_AddRule(t *testing.T) { assert.NotNil(t, fw.OutRules) ti, err := netip.ParsePrefix("1.2.3.4/32") - assert.NoError(t, err) + require.NoError(t, err) - assert.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") - assert.NoError(t, err) + require.NoError(t, err) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -155,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) { h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr @@ -174,29 +174,29 @@ func TestFirewall_Drop(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -350,14 +350,14 @@ func TestFirewall_Drop2(t *testing.T) { h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -428,23 +428,23 @@ func TestFirewall_Drop3(t *testing.T) { h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match - assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) - assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -480,29 +480,29 @@ func TestFirewall_DropConntrackReload(t *testing.T) { h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -585,42 +585,42 @@ func BenchmarkLookup(b *testing.B) { func Test_parsePort(t *testing.T) { _, _, err := parsePort("") - assert.EqualError(t, err, "was not a number; ``") + require.EqualError(t, err, "was not a number; ``") _, _, err = parsePort(" ") - assert.EqualError(t, err, "was not a number; ` `") + require.EqualError(t, err, "was not a number; ` `") _, _, err = parsePort("-") - assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`") + require.EqualError(t, err, "appears to be a range but could not be parsed; `-`") _, _, err = parsePort(" - ") - assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") + require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") _, _, err = parsePort("a-b") - assert.EqualError(t, err, "beginning range was not a number; `a`") + require.EqualError(t, err, "beginning range was not a number; `a`") _, _, err = parsePort("1-b") - assert.EqualError(t, err, "ending range was not a number; `b`") + require.EqualError(t, err, "ending range was not a number; `b`") s, e, err := parsePort(" 1 - 2 ") assert.Equal(t, int32(1), s) assert.Equal(t, int32(2), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("0-1") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("9919") assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("any") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.NoError(t, err) + require.NoError(t, err) } func TestNewFirewallFromConfig(t *testing.T) { @@ -633,53 +633,53 @@ func TestNewFirewallFromConfig(t *testing.T) { conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") + require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") + require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") + require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") + require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") + require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") + require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") + require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") + require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") + require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } func TestAddFirewallRulesFromConfig(t *testing.T) { @@ -688,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr @@ -717,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error @@ -767,7 +767,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") + require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } func TestFirewall_convertRule(t *testing.T) { @@ -782,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err := convertRule(l, c, "test", 1) assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "group1", r.Group) // Ensure group array of > 1 is errord @@ -793,7 +793,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err = convertRule(l, c, "test", 1) assert.Equal(t, "", ob.String()) - assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") + require.Error(t, err, "group should contain a single value, an array with more than one entry was provided") // Make sure a well formed group is alright ob.Reset() @@ -802,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) { } r, err = convertRule(l, c, "test", 1) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "group1", r.Group) } diff --git a/header/header_test.go b/header/header_test.go index 1836a75..a7e5374 100644 --- a/header/header_test.go +++ b/header/header_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type headerTest struct { @@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) { func TestHeader_MarshalJSON(t *testing.T) { b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", diff --git a/lighthouse_test.go b/lighthouse_test.go index 9e9ad53..3b1295a 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -21,7 +22,7 @@ func TestOldIPv4Only(t *testing.T) { b := []byte{8, 129, 130, 132, 80, 16, 10} var m V4AddrPort err := m.Unmarshal(b) - assert.NoError(t, err) + require.NoError(t, err) ip := netip.MustParseAddr("10.1.1.1") bp := ip.As4() assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) @@ -42,14 +43,14 @@ func Test_lhStaticMapping(t *testing.T) { c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") + require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func TestReloadLighthouseInterval(t *testing.T) { @@ -71,19 +72,19 @@ func TestReloadLighthouseInterval(t *testing.T) { c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) lh.ifce = &mockEncWriter{} // The first one routine is kicked off by main.go currently, lets make sure that one dies - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) assert.Equal(t, int64(5), lh.interval.Load()) // Subsequent calls are killed off by the LightHouse.Reload function - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) assert.Equal(t, int64(10), lh.interval.Load()) // If this completes then nothing is stealing our reload routine - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) assert.Equal(t, int64(11), lh.interval.Load()) } @@ -99,9 +100,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { c := config.NewC(l) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - if !assert.NoError(b, err) { - b.Fatal() - } + require.NoError(b, err) hAddr := netip.MustParseAddrPort("4.5.6.7:12345") hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") @@ -145,7 +144,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { }, } p, err := req.Marshal() - assert.NoError(b, err) + require.NoError(b, err) for n := 0; n < b.N; n++ { lhh.HandleRequest(rAddr, hi, p, mw) } @@ -160,7 +159,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { }, } p, err := req.Marshal() - assert.NoError(b, err) + require.NoError(b, err) for n := 0; n < b.N; n++ { lhh.HandleRequest(rAddr, hi, p, mw) @@ -205,7 +204,7 @@ func TestLighthouse_Memory(t *testing.T) { } lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh.ifce = &mockEncWriter{} - assert.NoError(t, err) + require.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that @@ -290,7 +289,7 @@ func TestLighthouse_reload(t *testing.T) { } lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) nc := map[interface{}]interface{}{ "static_host_map": map[interface{}]interface{}{ @@ -298,11 +297,11 @@ func TestLighthouse_reload(t *testing.T) { }, } rc, err := yaml.Marshal(nc) - assert.NoError(t, err) + require.NoError(t, err) c.ReloadConfigString(string(rc)) err = lh.reload(c, false) - assert.NoError(t, err) + require.NoError(t, err) } func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { diff --git a/outside_test.go b/outside_test.go index 944bf16..c63e57d 100644 --- a/outside_test.go +++ b/outside_test.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" ) @@ -20,13 +21,13 @@ func Test_newPacket(t *testing.T) { // length fails err := newPacket([]byte{}, true, p) - assert.ErrorIs(t, err, ErrPacketTooShort) + require.ErrorIs(t, err, ErrPacketTooShort) err = newPacket([]byte{0x40}, true, p) - assert.ErrorIs(t, err, ErrIPv4PacketTooShort) + require.ErrorIs(t, err, ErrIPv4PacketTooShort) err = newPacket([]byte{0x60}, true, p) - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ @@ -39,15 +40,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) + require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.ErrorIs(t, err, ErrUnknownIPVersion) + require.ErrorIs(t, err, ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) + require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ @@ -63,7 +64,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 3, 0, 4}...) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) @@ -85,7 +86,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 5, 0, 6}...) err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(2), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) @@ -111,10 +112,10 @@ func Test_newPacket_v6(t *testing.T) { FixLengths: false, } err := gopacket.SerializeLayers(buffer, opt, &ip) - assert.NoError(t, err) + require.NoError(t, err) err = newPacket(buffer.Bytes(), true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A good ICMP packet ip = layers.IPv6{ @@ -134,7 +135,7 @@ func Test_newPacket_v6(t *testing.T) { } err = newPacket(buffer.Bytes(), true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -146,7 +147,7 @@ func Test_newPacket_v6(t *testing.T) { b := buffer.Bytes() b[6] = byte(layers.IPProtocolESP) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -158,7 +159,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = byte(layers.IPProtocolNoNextHeader) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -170,7 +171,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = 255 // 255 is a reserved protocol number err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A good UDP packet ip = layers.IPv6{ @@ -186,7 +187,7 @@ func Test_newPacket_v6(t *testing.T) { DstPort: layers.UDPPort(22), } err = udp.SetNetworkLayerForChecksum(&ip) - assert.NoError(t, err) + require.NoError(t, err) buffer.Clear() err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) @@ -197,7 +198,7 @@ func Test_newPacket_v6(t *testing.T) { // incoming err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -207,7 +208,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -217,14 +218,14 @@ func Test_newPacket_v6(t *testing.T) { // Too short UDP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // A good TCP packet b[6] = byte(layers.IPProtocolTCP) // incoming err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -234,7 +235,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -244,7 +245,7 @@ func Test_newPacket_v6(t *testing.T) { // Too short TCP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // A good UDP packet with an AH header ip = layers.IPv6{ @@ -279,7 +280,7 @@ func Test_newPacket_v6(t *testing.T) { b = append(b, udpHeader...) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -290,7 +291,7 @@ func Test_newPacket_v6(t *testing.T) { // Invalid AH header b = buffer.Bytes() err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) } func Test_newPacket_ipv6Fragment(t *testing.T) { @@ -338,7 +339,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test first fragment incoming err = newPacket(firstFrag, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -348,7 +349,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test first fragment outgoing err = newPacket(firstFrag, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -377,7 +378,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test second fragment incoming err = newPacket(secondFrag, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -387,7 +388,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test second fragment outgoing err = newPacket(secondFrag, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -397,7 +398,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Too short of a fragment packet err = newPacket(secondFrag[:len(secondFrag)-10], false, p) - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) } func BenchmarkParseV6(b *testing.B) { diff --git a/overlay/route_test.go b/overlay/route_test.go index 4fa30af..8f2c094 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -8,84 +8,85 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) // test no routes config routes, err := parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "tun.routes is not an array") + require.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1 in tun.routes is invalid") + require.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") + require.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") + require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not present") + require.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") + require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") // Not in multiple ranges c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ @@ -93,7 +94,7 @@ func Test_parseRoutes(t *testing.T) { map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 2) tested := 0 @@ -119,36 +120,36 @@ func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) // test no routes config routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "tun.unsafe_routes is not an array") + require.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") + require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") // invalid via for _, invalidValue := range []interface{}{ @@ -157,44 +158,44 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) + require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) } // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.NoError(t, err) + require.NoError(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.NoError(t, err) + require.NoError(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} @@ -206,19 +207,19 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") + require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") // happy case c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ @@ -228,7 +229,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 4) tested := 0 @@ -260,38 +261,38 @@ func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) - assert.NoError(t, err) + require.NoError(t, err) ip, err := netip.ParseAddr("1.0.0.2") - assert.NoError(t, err) + require.NoError(t, err) r, ok := routeTree.Lookup(ip) assert.True(t, ok) nip, err := netip.ParseAddr("192.168.0.1") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, nip, r) ip, err = netip.ParseAddr("1.0.0.1") - assert.NoError(t, err) + require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.True(t, ok) nip, err = netip.ParseAddr("192.168.0.2") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, nip, r) ip, err = netip.ParseAddr("1.1.0.1") - assert.NoError(t, err) + require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.False(t, ok) } diff --git a/punchy_test.go b/punchy_test.go index 7918449..99d703d 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewPunchyFromConfig(t *testing.T) { @@ -56,7 +57,7 @@ func TestPunchy_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) delay, _ := time.ParseDuration("1m") - assert.NoError(t, c.LoadString(` + require.NoError(t, c.LoadString(` punchy: delay: 1m respond: false @@ -66,7 +67,7 @@ punchy: assert.False(t, p.GetRespond()) newDelay, _ := time.ParseDuration("10m") - assert.NoError(t, c.ReloadConfigString(` + require.NoError(t, c.ReloadConfigString(` punchy: delay: 10m respond: true From 2fb018ced85be1f254de77eb1703584642aad49d Mon Sep 17 00:00:00 2001 From: Aleksandr Zykov Date: Wed, 12 Mar 2025 04:58:52 +0100 Subject: [PATCH 05/44] Fixed homebrew formula path (#1219) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 56e4c9d..5eea0e2 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for $ sudo apk add nebula ``` -- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb) +- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb) ``` $ brew install nebula ``` From 1d3c85338c104a7869607b6c23272cec6026ea9e Mon Sep 17 00:00:00 2001 From: jampe Date: Wed, 12 Mar 2025 15:35:33 +0100 Subject: [PATCH 06/44] add so_mark sockopt support (#1331) --- examples/config.yml | 5 +++++ udp/udp_linux.go | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/examples/config.yml b/examples/config.yml index 1c3584e..aae0d98 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -144,6 +144,11 @@ listen: # valid values: always, never, private # This setting is reloadable. #send_recv_error: always + # The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier. + # This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes, + # allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set. + # This setting is reloadable. + #so_mark: 0 # Routines is the number of thread pairs to run that consume from the tun and UDP queues. # Currently, this defaults to 1 which means we have 1 tun queue reader and 1 diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 32a567e..f1936b4 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -84,6 +84,10 @@ func (u *StdConn) SetSendBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) } +func (u *StdConn) SetSoMark(mark int) error { + return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) +} + func (u *StdConn) GetRecvBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) } @@ -92,6 +96,10 @@ func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } +func (u *StdConn) GetSoMark() (int, error) { + return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { @@ -270,6 +278,22 @@ func (u *StdConn) ReloadConfig(c *config.C) { u.l.WithError(err).Error("Failed to set listen.write_buffer") } } + + b = c.GetInt("listen.so_mark", 0) + s, err := u.GetSoMark() + if b > 0 || (err == nil && s != 0) { + err := u.SetSoMark(b) + if err == nil { + s, err := u.GetSoMark() + if err == nil { + u.l.WithField("mark", s).Info("listen.so_mark was set") + } else { + u.l.WithError(err).Warn("Failed to get listen.so_mark") + } + } else { + u.l.WithError(err).Error("Failed to set listen.so_mark") + } + } } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { From 50473bd2a893404de88841464781ac7deaba9ea9 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Wed, 12 Mar 2025 22:53:16 -0500 Subject: [PATCH 07/44] Update example config to listen on `::` by default (#1351) --- examples/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index aae0d98..4e7a4ae 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -126,8 +126,8 @@ lighthouse: # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: - # To listen on both any ipv4 and ipv6 use "::" - host: 0.0.0.0 + # To listen on only ipv4, use "0.0.0.0" + host: "::" port: 4242 # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) # default is 64, does not support reload From 3de36c99b6c7e304a463128ae9319d96bfd822e9 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Fri, 14 Mar 2025 13:49:27 -0400 Subject: [PATCH 08/44] build with go1.24 (#1338) This doesn't change our go.mod, which still only requires go1.22 as a minimum. It only changes our builds to use go1.24 so we have the latest improvements. --- .github/workflows/gofmt.yml | 2 +- .github/workflows/release.yml | 6 +++--- .github/workflows/smoke.yml | 2 +- .github/workflows/test.yml | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index 20a39cf..288f32c 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Install goimports diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 392f71b..f9df115 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build @@ -37,7 +37,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build @@ -70,7 +70,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Import certificates diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index 3f63008..fc654da 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: build diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b8a4f03..28f0590 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build @@ -60,7 +60,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build @@ -102,7 +102,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build nebula From f86953ca56e623ec7629ef7753024d8ced944a72 Mon Sep 17 00:00:00 2001 From: dioss-Machiel Date: Mon, 24 Mar 2025 23:15:59 +0100 Subject: [PATCH 09/44] Implement ECMP for unsafe_routes (#1332) --- examples/config.yml | 23 ++++++- inside.go | 95 +++++++++++++++++++++++--- overlay/device.go | 4 +- overlay/route.go | 78 ++++++++++++++++++---- overlay/route_test.go | 112 ++++++++++++++++++++++++++++++- overlay/tun_android.go | 5 +- overlay/tun_darwin.go | 9 +-- overlay/tun_disabled.go | 5 +- overlay/tun_freebsd.go | 7 +- overlay/tun_ios.go | 5 +- overlay/tun_linux.go | 91 +++++++++++++++++++------ overlay/tun_netbsd.go | 7 +- overlay/tun_openbsd.go | 7 +- overlay/tun_tester.go | 5 +- overlay/tun_windows.go | 15 +++-- overlay/user.go | 11 ++- routing/balance.go | 39 +++++++++++ routing/balance_test.go | 144 ++++++++++++++++++++++++++++++++++++++++ routing/gateway.go | 70 +++++++++++++++++++ routing/gateway_test.go | 34 ++++++++++ test/tun.go | 6 +- 21 files changed, 690 insertions(+), 82 deletions(-) create mode 100644 routing/balance.go create mode 100644 routing/balance_test.go create mode 100644 routing/gateway.go create mode 100644 routing/gateway_test.go diff --git a/examples/config.yml b/examples/config.yml index 4e7a4ae..3b7c38b 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -239,7 +239,28 @@ tun: # Unsafe routes allows you to route traffic over nebula to non-nebula nodes # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula - # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate + # Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula + # NOTES: + # * You will only see a single gateway in the routing table if you are not on linux + # * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights + # + # unsafe_routes: + # # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways + # - route: 192.168.87.0/24 + # via: + # - gateway: 10.0.0.1 + # - gateway: 10.0.0.2 + # - gateway: 10.0.0.3 + # # Multiple gateways with a weight, this will balance traffic accordingly + # - route: 192.168.87.0/24 + # via: + # - gateway: 10.0.0.1 + # weight: 10 + # - gateway: 10.0.0.2 + # weight: 5 + # + # NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate + # `via`: single node or list of gateways to use for this route # `mtu`: will default to tun mtu if this option is not specified # `metric`: will default to 0 if this option is not specified # `install`: will default to true, controls whether this route is installed in the systems routing table. diff --git a/inside.go b/inside.go index 9629947..0af350d 100644 --- a/inside.go +++ b/inside.go @@ -8,6 +8,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/routing" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { @@ -49,7 +50,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) @@ -121,22 +122,94 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } +// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established func (f *Interface) Handshake(vpnAddr netip.Addr) { - f.getOrHandshake(vpnAddr, nil) + f.getOrHandshakeNoRouting(vpnAddr, nil) } -// getOrHandshake returns nil if the vpnAddr is not routable. +// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { +func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { _, found := f.myVpnNetworksTable.Lookup(vpnAddr) - if !found { - vpnAddr = f.inside.RouteFor(vpnAddr) - if !vpnAddr.IsValid() { - return nil, false - } + if found { + return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) + } + + return nil, false +} + +// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. +// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. +func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + + destinationAddr := fwPacket.RemoteAddr + + hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) + + // Host is inside the mesh, no routing required + if hostinfo != nil { + return hostinfo, ready + } + + gateways := f.inside.RoutesFor(destinationAddr) + + switch len(gateways) { + case 0: + return nil, false + case 1: + // Single gateway route + return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback) + default: + // Multi gateway route, perform ECMP categorization + gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways) + + if !balancingOk { + // This happens if the gateway buckets were not calculated, this _should_ never happen + f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.") + } + + var handshakeInfoForChosenGateway *HandshakeHostInfo + var hhReceiver = func(hh *HandshakeHostInfo) { + handshakeInfoForChosenGateway = hh + } + + // Store the handshakeHostInfo for later. + // If this node is not reachable we will attempt other nodes, if none are reachable we will + // cache the packet for this gateway. + if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready { + return hostinfo, true + } + + // It appears the selected gateway cannot be reached, find another gateway to fallback on. + // The current implementation breaks ECMP but that seems better than no connectivity. + // If ECMP is also required when a gateway is down then connectivity status + // for each gateway needs to be kept and the weights recalculated when they go up or down. + // This would also need to interact with unsafe_route updates through reloading the config or + // use of the use_system_route_table option + + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("destination", destinationAddr). + WithField("originalGateway", gatewayAddr). + Debugln("Calculated gateway for ECMP not available, attempting other gateways") + } + + for i := range gateways { + // Skip the gateway that failed previously + if gateways[i].Addr() == gatewayAddr { + continue + } + + // We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway + if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready { + return hostinfo, true + } + } + + // No gateways reachable, cache the packet in the originally chosen gateway + cacheCallback(handshakeInfoForChosenGateway) + return hostinfo, false } - return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -163,7 +236,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { + hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) diff --git a/overlay/device.go b/overlay/device.go index da8cbe9..07146ab 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -3,6 +3,8 @@ package overlay import ( "io" "net/netip" + + "github.com/slackhq/nebula/routing" ) type Device interface { @@ -10,6 +12,6 @@ type Device interface { Activate() error Networks() []netip.Prefix Name() string - RouteFor(netip.Addr) netip.Addr + RoutesFor(netip.Addr) routing.Gateways NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/route.go b/overlay/route.go index 687cc11..12364ec 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -11,13 +11,14 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) type Route struct { MTU int Metric int Cidr netip.Prefix - Via netip.Addr + Via routing.Gateways Install bool } @@ -47,15 +48,17 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) { - routeTree := new(bart.Table[netip.Addr]) +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { + routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } - if r.Via.IsValid() { - routeTree.Insert(r.Cidr, r.Via) + gateways := r.Via + if len(gateways) > 0 { + routing.CalculateBucketsForGateways(gateways) + routeTree.Insert(r.Cidr, gateways) } } return routeTree, nil @@ -201,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1) } - via, ok := rVia.(string) - if !ok { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) - } + var gateways routing.Gateways - viaVpnIp, err := netip.ParseAddr(via) - if err != nil { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) + switch via := rVia.(type) { + case string: + viaIp, err := netip.ParseAddr(via) + if err != nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) + } + + gateways = routing.Gateways{routing.NewGateway(viaIp, 1)} + + case []interface{}: + gateways = make(routing.Gateways, len(via)) + for ig, v := range via { + gatewayMap, ok := v.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1) + } + + rGateway, ok := gatewayMap["gateway"] + if !ok { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1) + } + + parsedGateway, ok := rGateway.(string) + if !ok { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1) + } + + gatewayIp, err := netip.ParseAddr(parsedGateway) + if err != nil { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err) + } + + rGatewayWeight, ok := gatewayMap["weight"] + if !ok { + rGatewayWeight = 1 + } + + gatewayWeight, ok := rGatewayWeight.(int) + if !ok { + _, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32) + if err != nil { + return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1) + } + } + + if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 { + return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight) + } + + gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight) + + } + + default: + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia) } rRoute, ok := m["route"] @@ -226,7 +278,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { } r := Route{ - Via: viaVpnIp, + Via: gateways, MTU: mtu, Metric: metric, Install: install, diff --git a/overlay/route_test.go b/overlay/route_test.go index 8f2c094..eb5e914 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -158,15 +159,39 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) + require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue)) } + // Unparsable list of via + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": []string{"1", "2"}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") + // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") + // unparsable gateway + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "1"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP") + + // missing gateway element + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"weight": "1"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present") + + // unparsable weight element + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "10.0.0.1", "weight": "a"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") + // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) @@ -280,7 +305,7 @@ func Test_makeRouteTree(t *testing.T) { nip, err := netip.ParseAddr("192.168.0.1") require.NoError(t, err) - assert.Equal(t, nip, r) + assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("1.0.0.1") require.NoError(t, err) @@ -289,10 +314,91 @@ func Test_makeRouteTree(t *testing.T) { nip, err = netip.ParseAddr("192.168.0.2") require.NoError(t, err) - assert.Equal(t, nip, r) + assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("1.1.0.1") require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.False(t, ok) } + +func Test_makeMultipathUnsafeRouteTree(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + n, err := netip.ParsePrefix("10.0.0.0/24") + require.NoError(t, err) + + c.Settings["tun"] = map[interface{}]interface{}{ + "unsafe_routes": []interface{}{ + map[interface{}]interface{}{ + "route": "192.168.86.0/24", + "via": "192.168.100.10", + }, + map[interface{}]interface{}{ + "route": "192.168.87.0/24", + "via": []interface{}{ + map[interface{}]interface{}{ + "gateway": "10.0.0.1", + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.2", + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.3", + }, + }, + }, + map[interface{}]interface{}{ + "route": "192.168.89.0/24", + "via": []interface{}{ + map[interface{}]interface{}{ + "gateway": "10.0.0.1", + "weight": 10, + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.2", + "weight": 5, + }, + }, + }, + }, + } + + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) + assert.Len(t, routes, 3) + routeTree, err := makeRouteTree(l, routes, true) + require.NoError(t, err) + + ip, err := netip.ParseAddr("192.168.86.1") + require.NoError(t, err) + r, ok := routeTree.Lookup(ip) + assert.True(t, ok) + + nip, err := netip.ParseAddr("192.168.100.10") + require.NoError(t, err) + assert.Equal(t, nip, r[0].Addr()) + + ip, err = netip.ParseAddr("192.168.87.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) + assert.True(t, ok) + + expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1), + routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1), + routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)} + + routing.CalculateBucketsForGateways(expectedGateways) + assert.ElementsMatch(t, expectedGateways, r) + + ip, err = netip.ParseAddr("192.168.89.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) + assert.True(t, ok) + + expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10), + routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)} + + routing.CalculateBucketsForGateways(expectedGateways) + assert.ElementsMatch(t, expectedGateways, r) +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 72a6565..df1ed8d 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -13,6 +13,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -21,7 +22,7 @@ type tun struct { fd int vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger } @@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro return nil, fmt.Errorf("newTun not supported in Android") } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 1a02b49..d2b2896 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" @@ -28,7 +29,7 @@ type tun struct { vpnNetworks []netip.Prefix DefaultMTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr l *logrus.Logger @@ -342,12 +343,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } - return netip.Addr{} + return routing.Gateways{} } // Get the LinkAddr for the interface of the given name @@ -382,7 +383,7 @@ func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index cfbf17d..131879d 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -9,6 +9,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) type disabledTun struct { @@ -43,8 +44,8 @@ func (*disabledTun) Activate() error { return nil } -func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { - return netip.Addr{} +func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways { + return routing.Gateways{} } func (t *disabledTun) Networks() []netip.Prefix { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 69690e9..bcb82b3 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -20,6 +20,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -50,7 +51,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -242,7 +243,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -262,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index e99d447..e51e112 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -23,7 +24,7 @@ type tun struct { io.ReadWriteCloser vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger } @@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 993bd4a..809536f 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" @@ -34,7 +35,7 @@ type tun struct { ioctlFd uintptr Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeChan chan struct{} useSystemRoutes bool @@ -231,7 +232,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -550,20 +551,7 @@ func (t *tun) watchRoutes() { }() } -func (t *tun) updateRoutes(r netlink.RouteUpdate) { - if r.Gw == nil { - // Not a gateway route, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route") - return - } - - gwAddr, ok := netip.AddrFromSlice(r.Gw) - if !ok { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") - return - } - - gwAddr = gwAddr.Unmap() +func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool { withinNetworks := false for i := range t.vpnNetworks { if t.vpnNetworks[i].Contains(gwAddr) { @@ -571,9 +559,68 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { break } } - if !withinNetworks { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") + + return withinNetworks +} + +func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { + + var gateways routing.Gateways + + link, err := netlink.LinkByName(t.Device) + if err != nil { + t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name") + return gateways + } + + // If this route is relevant to our interface and there is a gateway then add it + if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 { + gwAddr, ok := netip.AddrFromSlice(r.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") + } else { + gwAddr = gwAddr.Unmap() + + if !t.isGatewayInVpnNetworks(gwAddr) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + } else { + gateways = append(gateways, routing.NewGateway(gwAddr, 1)) + } + } + } + + for _, p := range r.MultiPath { + // If this route is relevant to our interface and there is a gateway then add it + if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 { + gwAddr, ok := netip.AddrFromSlice(p.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address") + } else { + gwAddr = gwAddr.Unmap() + + if !t.isGatewayInVpnNetworks(gwAddr) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + } else { + // p.Hops+1 = weight of the route + gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) + } + } + } + } + + routing.CalculateBucketsForGateways(gateways) + return gateways +} + +func (t *tun) updateRoutes(r netlink.RouteUpdate) { + + gateways := t.getGatewaysFromRoute(&r.Route) + + if len(gateways) == 0 { + // No gateways relevant to our network, no routing changes required. + t.l.WithField("route", r).Debug("Ignoring route update, no gateways") return } @@ -589,12 +636,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { newTree := t.routeTree.Load().Clone() if r.Type == unix.RTM_NEWROUTE { - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") - newTree.Insert(dst, gwAddr) + t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") + newTree.Insert(dst, gateways) } else { + t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") newTree.Delete(dst) - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") } t.routeTree.Store(newTree) } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index f7586cb..847f1b5 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -18,6 +18,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -31,7 +32,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -177,7 +178,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -197,7 +198,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index a2fd184..03fb3a0 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -25,7 +26,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -158,7 +159,7 @@ func (t *tun) Activate() error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -166,7 +167,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index cc3942f..b6712fb 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -13,13 +13,14 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) type TestTun struct { Device string vpnNetworks []netip.Prefix Routes []Route - routeTree *bart.Table[netip.Addr] + routeTree *bart.Table[routing.Gateways] l *logrus.Logger closed atomic.Bool @@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr { +func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Lookup(ip) return r } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 289999d..1d66eac 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -18,6 +18,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" @@ -31,7 +32,7 @@ type winTun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger tun *wintun.NativeTun @@ -147,13 +148,16 @@ func (t *winTun) addRoutes(logErrors bool) error { foundDefault4 := false for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } // Add our unsafe route - err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) + // Windows does not support multipath routes natively, so we install only a single route. + // This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. + // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. + err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) if logErrors { @@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error { continue } - err := luid.DeleteRoute(r.Cidr, r.Via) + // See comment on luid.AddRoute + err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error { return nil } -func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { +func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/user.go b/overlay/user.go index ae665f3..8a56d66 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { @@ -38,9 +39,13 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } + +func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways { + return routing.Gateways{routing.NewGateway(ip, 1)} +} + func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } diff --git a/routing/balance.go b/routing/balance.go new file mode 100644 index 0000000..6f52497 --- /dev/null +++ b/routing/balance.go @@ -0,0 +1,39 @@ +package routing + +import ( + "net/netip" + + "github.com/slackhq/nebula/firewall" +) + +// Hashes the packet source and destination port and always returns a positive integer +// Based on 'Prospecting for Hash Functions' +// - https://nullprogram.com/blog/2018/07/31/ +// - https://github.com/skeeto/hash-prospector +// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501 +func hashPacket(p *firewall.Packet) int { + x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + x ^= x >> 16 + x *= 0x21f0aaad + x ^= x >> 15 + x *= 0xd35a2d97 + x ^= x >> 15 + + return int(x) & 0x7FFFFFFF +} + +// For this function to work correctly it requires that the buckets for the gateways have been calculated +// If the contract is violated balancing will not work properly and the second return value will return false +func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) { + hash := hashPacket(fwPacket) + + for i := range gateways { + if hash <= gateways[i].BucketUpperBound() { + return gateways[i].Addr(), true + } + } + + // If you land here then the buckets for the gateways are not properly calculated + // Fallback to random routing and let the caller know + return gateways[hash%len(gateways)].Addr(), false +} diff --git a/routing/balance_test.go b/routing/balance_test.go new file mode 100644 index 0000000..bbfcb22 --- /dev/null +++ b/routing/balance_test.go @@ -0,0 +1,144 @@ +package routing + +import ( + "net/netip" + "testing" + + "github.com/slackhq/nebula/firewall" + "github.com/stretchr/testify/assert" +) + +func TestPacketsAreBalancedEqually(t *testing.T) { + + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + gw3Addr := netip.MustParseAddr("1.0.0.3") + + gateways = append(gateways, NewGateway(gw1Addr, 1)) + gateways = append(gateways, NewGateway(gw2Addr, 1)) + gateways = append(gateways, NewGateway(gw3Addr, 1)) + + CalculateBucketsForGateways(gateways) + + gw1count := 0 + gw2count := 0 + gw3count := 0 + + iterationCount := uint16(65535) + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.True(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + case gw3Addr: + gw3count += 1 + } + + } + + // Assert packets are balanced, allow variation of up to 100 packets per gateway + assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + +} + +func TestPacketsAreBalancedByPriority(t *testing.T) { + + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + + gateways = append(gateways, NewGateway(gw1Addr, 10)) + gateways = append(gateways, NewGateway(gw2Addr, 5)) + + CalculateBucketsForGateways(gateways) + + gw1count := 0 + gw2count := 0 + + iterationCount := uint16(65535) + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.True(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + } + + } + + iterationCountAsFloat := float32(iterationCount) + + assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count) + assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count) +} + +func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) { + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + + gateways = append(gateways, NewGateway(gw1Addr, 10)) + gateways = append(gateways, NewGateway(gw2Addr, 5)) + + iterationCount := uint16(65535) + gw1count := 0 + gw2count := 0 + + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.False(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + } + + } + + assert.Equal(t, int(iterationCount), (gw1count + gw2count)) + assert.NotEqual(t, 0, gw1count) + assert.NotEqual(t, 0, gw2count) + +} diff --git a/routing/gateway.go b/routing/gateway.go new file mode 100644 index 0000000..59d38a9 --- /dev/null +++ b/routing/gateway.go @@ -0,0 +1,70 @@ +package routing + +import ( + "fmt" + "net/netip" +) + +const ( + // Sentinal value + BucketNotCalculated = -1 +) + +type Gateways []Gateway + +func (g Gateways) String() string { + str := "" + for i, gw := range g { + str += gw.String() + if i < len(g)-1 { + str += ", " + } + } + return str +} + +type Gateway struct { + addr netip.Addr + weight int + bucketUpperBound int +} + +func NewGateway(addr netip.Addr, weight int) Gateway { + return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated} +} + +func (g *Gateway) BucketUpperBound() int { + return g.bucketUpperBound +} + +func (g *Gateway) Addr() netip.Addr { + return g.addr +} + +func (g *Gateway) String() string { + return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight) +} + +// Divide and round to nearest integer +func divideAndRound(v uint64, d uint64) uint64 { + var tmp uint64 = v + d/2 + return tmp / d +} + +// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel. +// After this function returns each gateway will have a +// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX) +func CalculateBucketsForGateways(gateways []Gateway) { + + var totalWeight int = 0 + for i := range gateways { + totalWeight += gateways[i].weight + } + + var loopWeight int = 0 + for i := range gateways { + loopWeight += gateways[i].weight + gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1 + } + +} diff --git a/routing/gateway_test.go b/routing/gateway_test.go new file mode 100644 index 0000000..8ae78f3 --- /dev/null +++ b/routing/gateway_test.go @@ -0,0 +1,34 @@ +package routing + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRebalance3_2Split(t *testing.T) { + gateways := []Gateway{} + + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5}) + + CalculateBucketsForGateways(gateways) + + assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2 + assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX +} + +func TestRebalanceEqualSplit(t *testing.T) { + gateways := []Gateway{} + + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + + CalculateBucketsForGateways(gateways) + + assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3 + assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2 + assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX +} diff --git a/test/tun.go b/test/tun.go index b29d61c..ca65805 100644 --- a/test/tun.go +++ b/test/tun.go @@ -4,12 +4,14 @@ import ( "errors" "io" "net/netip" + + "github.com/slackhq/nebula/routing" ) type NoopTun struct{} -func (NoopTun) RouteFor(addr netip.Addr) netip.Addr { - return netip.Addr{} +func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { + return routing.Gateways{} } func (NoopTun) Activate() error { From 4444ed166ac163bcf4296d62d826c06b3376957b Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Tue, 25 Mar 2025 16:08:36 -0500 Subject: [PATCH 10/44] Add `certVersion` field to logs when logging the cert name in handshakes (#1359) --- handshake_ix.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/handshake_ix.go b/handshake_ix.go index daea526..0783999 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -71,7 +71,8 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("certVersion", v). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -185,6 +186,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet var vpnAddrs []netip.Addr var filteredNetworks []netip.Prefix certName := remoteCert.Certificate.Name() + certVersion := remoteCert.Certificate.Version() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() @@ -194,6 +196,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if found { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") @@ -212,6 +215,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if len(vpnAddrs) == 0 { f.l.WithError(err).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") @@ -231,6 +235,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if err != nil { f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") @@ -253,6 +258,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -264,6 +270,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if hs.Details.Cert == nil { f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -281,6 +288,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if err != nil { f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") @@ -292,6 +300,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if err != nil { f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") @@ -299,6 +308,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } else if dKey == nil || eKey == nil { f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") @@ -366,6 +376,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // This means there was an existing tunnel and this handshake was older than the one we are currently based on f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime). WithField("fingerprint", fingerprint). @@ -381,6 +392,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -393,6 +405,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // And we forget to update it here f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -409,6 +422,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if err != nil { f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -417,6 +431,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } else { f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -435,6 +450,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -539,6 +555,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha vpnNetworks := remoteCert.Certificate.Networks() certName := remoteCert.Certificate.Name() + certVersion := remoteCert.Certificate.Version() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() @@ -573,6 +590,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha if len(vpnAddrs) == 0 { f.l.WithError(err).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") @@ -582,7 +600,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // Ensure the right host responded if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). - WithField("udpAddr", addr).WithField("certName", certName). + WithField("udpAddr", addr). + WithField("certName", certName). + WithField("certVersion", certVersion). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Incorrect host responded to handshake") @@ -618,6 +638,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha duration := time.Since(hh.startTime).Nanoseconds() f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). From 75faa5f2e5f551e21fcb75a9aeb3805366f30d90 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 16:05:07 -0400 Subject: [PATCH 11/44] Bump golang.org/x/net in the golang-x-dependencies group (#1370) Bumps the golang-x-dependencies group with 1 update: [golang.org/x/net](https://github.com/golang/net). Updates `golang.org/x/net` from 0.37.0 to 0.38.0 - [Commits](https://github.com/golang/net/compare/v0.37.0...v0.38.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index de09c18..3b13170 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/slackhq/nebula go 1.23.6 -toolchain go1.23.7 +toolchain go1.24.1 require ( dario.cat/mergo v1.0.1 @@ -26,7 +26,7 @@ require ( github.com/vishvananda/netlink v1.3.0 golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.37.0 + golang.org/x/net v0.38.0 golang.org/x/sync v0.12.0 golang.org/x/sys v0.31.0 golang.org/x/term v0.30.0 diff --git a/go.sum b/go.sum index 11f57c7..78f2671 100644 --- a/go.sum +++ b/go.sum @@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= From 879852c32a385ac5059af91d89615178fcef532c Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 31 Mar 2025 16:08:34 -0400 Subject: [PATCH 12/44] upgrade to yaml.v3 (#1148) * upgrade to yaml.v3 The main nice fix here is that maps unmarshal into `map[string]any` instead of `map[any]any`, so it cleans things up a bit. * add config.AsBool Since yaml.v3 doesn't automatically convert yes to bool now, for backwards compat * use type aliases for m * more cleanup * more cleanup * more cleanup * go mod cleanup --- allow_list.go | 38 ++++++----------- allow_list_test.go | 24 +++++------ cert/cert_v1.go | 2 +- cmd/nebula-cert/main.go | 2 +- config/config.go | 50 ++++++++++++++-------- config/config_test.go | 38 ++++++++--------- control_test.go | 2 +- dns_server_test.go | 16 +++---- e2e/handshakes_test.go | 6 +-- e2e/helpers_test.go | 4 +- firewall.go | 10 ++--- firewall/packet.go | 2 +- firewall_test.go | 52 +++++++++++------------ go.mod | 3 +- go.sum | 2 - header/header.go | 2 +- lighthouse.go | 6 +-- lighthouse_test.go | 30 +++++++------- main.go | 4 +- overlay/route.go | 8 ++-- overlay/route_test.go | 76 +++++++++++++++++----------------- overlay/tun_darwin.go | 2 +- overlay/tun_freebsd.go | 2 +- overlay/tun_linux.go | 2 +- overlay/tun_netbsd.go | 2 +- overlay/tun_openbsd.go | 2 +- overlay/tun_windows.go | 2 +- punchy_test.go | 8 ++-- service/service_test.go | 4 +- ssh.go | 92 ++++++++++++++++++++--------------------- sshd/command.go | 10 ++--- sshd/server.go | 2 +- sshd/session.go | 2 +- test/assert.go | 2 +- util/error.go | 4 +- util/error_test.go | 2 +- 36 files changed, 257 insertions(+), 258 deletions(-) diff --git a/allow_list.go b/allow_list.go index cfdd983..cba56fc 100644 --- a/allow_list.go +++ b/allow_list.go @@ -36,7 +36,7 @@ type AllowListNameRule struct { func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) { var nameRules []AllowListNameRule - handleKey := func(key string, value interface{}) (bool, error) { + handleKey := func(key string, value any) (bool, error) { if key == "interfaces" { var err error nameRules, err = getAllowListInterfaces(k, value) @@ -70,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo // If the handleKey func returns true, the rest of the parsing is skipped // for this key. This allows parsing of special values like `interfaces`. -func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { +func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) { r := c.Get(k) if r == nil { return nil, nil @@ -81,8 +81,8 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va // If the handleKey func returns true, the rest of the parsing is skipped // for this key. This allows parsing of special values like `interfaces`. -func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { - rawMap, ok := raw.(map[interface{}]interface{}) +func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) { + rawMap, ok := raw.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } @@ -100,12 +100,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} - for rawKey, rawValue := range rawMap { - rawCIDR, ok := rawKey.(string) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) - } - + for rawCIDR, rawValue := range rawMap { if handleKey != nil { handled, err := handleKey(rawCIDR, rawValue) if err != nil { @@ -116,7 +111,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in } } - value, ok := rawValue.(bool) + value, ok := config.AsBool(rawValue) if !ok { return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) } @@ -173,22 +168,18 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return &AllowList{cidrTree: tree}, nil } -func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) { +func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) { var nameRules []AllowListNameRule - rawRules, ok := v.(map[interface{}]interface{}) + rawRules, ok := v.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v) } firstEntry := true var allValues bool - for rawName, rawAllow := range rawRules { - name, ok := rawName.(string) - if !ok { - return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName) - } - allow, ok := rawAllow.(bool) + for name, rawAllow := range rawRules { + allow, ok := config.AsBool(rawAllow) if !ok { return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow) } @@ -224,16 +215,11 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error remoteAllowRanges := new(bart.Table[*AllowList]) - rawMap, ok := value.(map[interface{}]interface{}) + rawMap, ok := value.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) } - for rawKey, rawValue := range rawMap { - rawCIDR, ok := rawKey.(string) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) - } - + for rawCIDR, rawValue := range rawMap { allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil) if err != nil { return nil, err diff --git a/allow_list_test.go b/allow_list_test.go index d7d2c9a..6135f36 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -15,27 +15,27 @@ import ( func TestNewAllowListFromConfig(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "192.168.0.0/16": "abc", } r, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "192.168.0.0/16": true, "10.0.0.0/8": false, } r, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "0.0.0.0/0": true, "10.0.0.0/8": false, "10.42.42.0/24": true, @@ -45,7 +45,7 @@ func TestNewAllowListFromConfig(t *testing.T) { r, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "0.0.0.0/0": true, "10.0.0.0/8": false, "10.42.42.0/24": true, @@ -55,7 +55,7 @@ func TestNewAllowListFromConfig(t *testing.T) { assert.NotNil(t, r) } - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "0.0.0.0/0": true, "10.0.0.0/8": false, "10.42.42.0/24": true, @@ -70,16 +70,16 @@ func TestNewAllowListFromConfig(t *testing.T) { // Test interface names - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ + "interfaces": map[string]any{ `docker.*`: "foo", }, } lr, err := NewLocalAllowListFromConfig(c, "allowlist") require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ + "interfaces": map[string]any{ `docker.*`: false, `eth.*`: true, }, @@ -87,8 +87,8 @@ func TestNewAllowListFromConfig(t *testing.T) { lr, err = NewLocalAllowListFromConfig(c, "allowlist") require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ + "interfaces": map[string]any{ `docker.*`: false, }, } diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 6bb146f..71d36eb 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -41,7 +41,7 @@ type detailsV1 struct { curve Curve } -type m map[string]interface{} +type m = map[string]any func (c *certificateV1) Version() Version { return Version1 diff --git a/cmd/nebula-cert/main.go b/cmd/nebula-cert/main.go index b803d30..c88626f 100644 --- a/cmd/nebula-cert/main.go +++ b/cmd/nebula-cert/main.go @@ -17,7 +17,7 @@ func (he *helpError) Error() string { return he.s } -func newHelpErrorf(s string, v ...interface{}) error { +func newHelpErrorf(s string, v ...any) error { return &helpError{s: fmt.Sprintf(s, v...)} } diff --git a/config/config.go b/config/config.go index 1aea832..b1531e9 100644 --- a/config/config.go +++ b/config/config.go @@ -17,14 +17,14 @@ import ( "dario.cat/mergo" "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) type C struct { path string files []string - Settings map[interface{}]interface{} - oldSettings map[interface{}]interface{} + Settings map[string]any + oldSettings map[string]any callbacks []func(*C) l *logrus.Logger reloadLock sync.Mutex @@ -32,7 +32,7 @@ type C struct { func NewC(l *logrus.Logger) *C { return &C{ - Settings: make(map[interface{}]interface{}), + Settings: make(map[string]any), l: l, } } @@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool { } var ( - nv interface{} - ov interface{} + nv any + ov any ) if k == "" { @@ -147,7 +147,7 @@ func (c *C) ReloadConfig() { c.reloadLock.Lock() defer c.reloadLock.Unlock() - c.oldSettings = make(map[interface{}]interface{}) + c.oldSettings = make(map[string]any) for k, v := range c.Settings { c.oldSettings[k] = v } @@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error { c.reloadLock.Lock() defer c.reloadLock.Unlock() - c.oldSettings = make(map[interface{}]interface{}) + c.oldSettings = make(map[string]any) for k, v := range c.Settings { c.oldSettings[k] = v } @@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string { return d } - rv, ok := r.([]interface{}) + rv, ok := r.([]any) if !ok { return d } @@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string { } // GetMap will get the map for k or return the default d if not found or invalid -func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} { +func (c *C) GetMap(k string, d map[string]any) map[string]any { r := c.Get(k) if r == nil { return d } - v, ok := r.(map[interface{}]interface{}) + v, ok := r.(map[string]any) if !ok { return d } @@ -266,6 +266,22 @@ func (c *C) GetBool(k string, d bool) bool { return v } +func AsBool(v any) (value bool, ok bool) { + switch x := v.(type) { + case bool: + return x, true + case string: + switch x { + case "y", "yes": + return true, true + case "n", "no": + return false, true + } + } + + return false, false +} + // GetDuration will get the duration for k or return the default d if not found or invalid func (c *C) GetDuration(k string, d time.Duration) time.Duration { r := c.GetString(k, "") @@ -276,7 +292,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration { return v } -func (c *C) Get(k string) interface{} { +func (c *C) Get(k string) any { return c.get(k, c.Settings) } @@ -284,10 +300,10 @@ func (c *C) IsSet(k string) bool { return c.get(k, c.Settings) != nil } -func (c *C) get(k string, v interface{}) interface{} { +func (c *C) get(k string, v any) any { parts := strings.Split(k, ".") for _, p := range parts { - m, ok := v.(map[interface{}]interface{}) + m, ok := v.(map[string]any) if !ok { return nil } @@ -346,7 +362,7 @@ func (c *C) addFile(path string, direct bool) error { } func (c *C) parseRaw(b []byte) error { - var m map[interface{}]interface{} + var m map[string]any err := yaml.Unmarshal(b, &m) if err != nil { @@ -358,7 +374,7 @@ func (c *C) parseRaw(b []byte) error { } func (c *C) parse() error { - var m map[interface{}]interface{} + var m map[string]any for _, path := range c.files { b, err := os.ReadFile(path) @@ -366,7 +382,7 @@ func (c *C) parse() error { return err } - var nm map[interface{}]interface{} + var nm map[string]any err = yaml.Unmarshal(b, &nm) if err != nil { return err diff --git a/config/config_test.go b/config/config_test.go index 468c642..ec5a4b0 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -10,7 +10,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) func TestConfig_Load(t *testing.T) { @@ -19,7 +19,7 @@ func TestConfig_Load(t *testing.T) { // invalid yaml c := NewC(l) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) - require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") + require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}") // simple multi config merge c = NewC(l) @@ -31,8 +31,8 @@ func TestConfig_Load(t *testing.T) { os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) require.NoError(t, c.Load(dir)) - expected := map[interface{}]interface{}{ - "outer": map[interface{}]interface{}{ + expected := map[string]any{ + "outer": map[string]any{ "inner": "override", }, "new": "hi", @@ -44,12 +44,12 @@ func TestConfig_Get(t *testing.T) { l := test.NewLogger() // test simple type c := NewC(l) - c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} + c.Settings["firewall"] = map[string]any{"outbound": "hi"} assert.Equal(t, "hi", c.Get("firewall.outbound")) // test complex type - inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}} - c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner} + inner := []map[string]any{{"port": "1", "code": "2"}} + c.Settings["firewall"] = map[string]any{"outbound": inner} assert.EqualValues(t, inner, c.Get("firewall.outbound")) // test missing @@ -59,7 +59,7 @@ func TestConfig_Get(t *testing.T) { func TestConfig_GetStringSlice(t *testing.T) { l := test.NewLogger() c := NewC(l) - c.Settings["slice"] = []interface{}{"one", "two"} + c.Settings["slice"] = []any{"one", "two"} assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) } @@ -101,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) { // Test key change c = NewC(l) c.Settings["test"] = "hi" - c.oldSettings = map[interface{}]interface{}{"test": "no"} + c.oldSettings = map[string]any{"test": "no"} assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("")) // No key change c = NewC(l) c.Settings["test"] = "hi" - c.oldSettings = map[interface{}]interface{}{"test": "hi"} + c.oldSettings = map[string]any{"test": "hi"} assert.False(t, c.HasChanged("test")) assert.False(t, c.HasChanged("")) } @@ -184,11 +184,11 @@ firewall: `), } - var m map[any]any + var m map[string]any // merge the same way config.parse() merges for _, b := range configs { - var nm map[any]any + var nm map[string]any err := yaml.Unmarshal(b, &nm) require.NoError(t, err) @@ -205,15 +205,15 @@ firewall: t.Logf("Merged Config as YAML:\n%s", mYaml) // If a bug is present, some items might be replaced instead of merged like we expect - expected := map[any]any{ - "firewall": map[any]any{ + expected := map[string]any{ + "firewall": map[string]any{ "inbound": []any{ - map[any]any{"host": "any", "port": "any", "proto": "icmp"}, - map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"}, - map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}}, + map[string]any{"host": "any", "port": "any", "proto": "icmp"}, + map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"}, + map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}}, "outbound": []any{ - map[any]any{"host": "any", "port": "any", "proto": "any"}}}, - "listen": map[any]any{ + map[string]any{"host": "any", "port": "any", "proto": "any"}}}, + "listen": map[string]any{ "host": "0.0.0.0", "port": 4242, }, diff --git a/control_test.go b/control_test.go index 6ce7083..de85fee 100644 --- a/control_test.go +++ b/control_test.go @@ -110,7 +110,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }) } -func assertFields(t *testing.T, expected []string, actualStruct interface{}) { +func assertFields(t *testing.T, expected []string, actualStruct any) { val := reflect.ValueOf(actualStruct).Elem() fields := make([]string, val.NumField()) for i := 0; i < val.NumField(); i++ { diff --git a/dns_server_test.go b/dns_server_test.go index f4643a3..356e589 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -38,24 +38,24 @@ func TestParsequery(t *testing.T) { func Test_getDnsServerAddr(t *testing.T) { c := config.NewC(nil) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "0.0.0.0", "port": "1", }, } assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "::", "port": "1", }, } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "[::]", "port": "1", }, @@ -63,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) { assert.Equal(t, "[::]:1", getDnsServerAddr(c)) // Make sure whitespace doesn't mess us up - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "[::] ", "port": "1", }, diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 06f2a21..bc080ce 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -20,7 +20,7 @@ import ( "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) func BenchmarkHotPath(b *testing.B) { @@ -991,7 +991,7 @@ func TestRehandshaking(t *testing.T) { require.NoError(t, err) var theirNewConfig m require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) - theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) + theirFirewall := theirNewConfig["firewall"].(map[string]any) theirFirewall["inbound"] = []m{{ "proto": "any", "port": "any", @@ -1087,7 +1087,7 @@ func TestRehandshakingLoser(t *testing.T) { require.NoError(t, err) var myNewConfig m require.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) - theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) + theirFirewall := myNewConfig["firewall"].(map[string]any) theirFirewall["inbound"] = []m{{ "proto": "any", "port": "any", diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index e1b7ac2..a63b3d0 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -22,10 +22,10 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any // newSimpleServer creates a nebula instance with many assumptions func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { diff --git a/firewall.go b/firewall.go index e9f454d..e730114 100644 --- a/firewall.go +++ b/firewall.go @@ -331,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return nil } - rs, ok := r.([]interface{}) + rs, ok := r.([]any) if !ok { return fmt.Errorf("%s failed to parse, should be an array of rules", table) } @@ -918,15 +918,15 @@ type rule struct { CASha string } -func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) { +func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { r := rule{} - m, ok := p.(map[interface{}]interface{}) + m, ok := p.(map[string]any) if !ok { return r, errors.New("could not parse rule") } - toString := func(k string, m map[interface{}]interface{}) string { + toString := func(k string, m map[string]any) string { v, ok := m[k] if !ok { return "" @@ -944,7 +944,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er r.CASha = toString("ca_sha", m) // Make sure group isn't an array - if v, ok := m["group"].([]interface{}); ok { + if v, ok := m["group"].([]any); ok { if len(v) > 1 { return r, errors.New("group should contain a single value, an array with more than one entry was provided") } diff --git a/firewall/packet.go b/firewall/packet.go index 1d8f12a..40c7fc5 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -6,7 +6,7 @@ import ( "net/netip" ) -type m map[string]interface{} +type m = map[string]any const ( ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever diff --git a/firewall_test.go b/firewall_test.go index 8c2eeb0..c90fb20 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -631,53 +631,53 @@ func TestNewFirewallFromConfig(t *testing.T) { require.NoError(t, err) conf := config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} + conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } @@ -687,28 +687,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding tcp rule conf := config.NewC(l) mf := &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) @@ -716,49 +716,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { cidr := netip.MustParsePrefix("10.0.0.0/8") conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) @@ -766,7 +766,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } @@ -776,8 +776,8 @@ func TestFirewall_convertRule(t *testing.T) { l.SetOutput(ob) // Ensure group array of 1 is converted and a warning is printed - c := map[interface{}]interface{}{ - "group": []interface{}{"group1"}, + c := map[string]any{ + "group": []any{"group1"}, } r, err := convertRule(l, c, "test", 1) @@ -787,8 +787,8 @@ func TestFirewall_convertRule(t *testing.T) { // Ensure group array of > 1 is errord ob.Reset() - c = map[interface{}]interface{}{ - "group": []interface{}{"group1", "group2"}, + c = map[string]any{ + "group": []any{"group1", "group2"}, } r, err = convertRule(l, c, "test", 1) @@ -797,7 +797,7 @@ func TestFirewall_convertRule(t *testing.T) { // Make sure a well formed group is alright ob.Reset() - c = map[interface{}]interface{}{ + c = map[string]any{ "group": "group1", } diff --git a/go.mod b/go.mod index 3b13170..bbd5d8b 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.36.5 - gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) @@ -53,5 +53,4 @@ require ( golang.org/x/mod v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.22.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 78f2671..8237bfa 100644 --- a/go.sum +++ b/go.sum @@ -251,8 +251,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/header/header.go b/header/header.go index 50b7d62..f22509b 100644 --- a/header/header.go +++ b/header/header.go @@ -19,7 +19,7 @@ import ( // |-----------------------------------------------------------------------| // | payload... | -type m map[string]interface{} +type m = map[string]any const ( Version uint8 = 1 diff --git a/lighthouse.go b/lighthouse.go index ce37023..f13afd3 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -422,7 +422,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return err } - shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) + shm := c.GetMap("static_host_map", map[string]any{}) i := 0 for k, v := range shm { @@ -436,9 +436,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } - vals, ok := v.([]interface{}) + vals, ok := v.([]any) if !ok { - vals = []interface{}{v} + vals = []any{v} } remoteAddrs := []string{} for _, v := range vals { diff --git a/lighthouse_test.go b/lighthouse_test.go index 3b1295a..6a541c2 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -14,7 +14,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) func TestOldIPv4Only(t *testing.T) { @@ -40,15 +40,15 @@ func Test_lhStaticMapping(t *testing.T) { lh1 := "10.128.0.2" c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} + c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}} + c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} + c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}} + c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}} _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } @@ -65,12 +65,12 @@ func TestReloadLighthouseInterval(t *testing.T) { lh1 := "10.128.0.2" c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "hosts": []interface{}{lh1}, + c.Settings["lighthouse"] = map[string]any{ + "hosts": []any{lh1}, "interval": "1s", } - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} + c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -192,8 +192,8 @@ func TestLighthouse_Memory(t *testing.T) { theirVpnIp := netip.MustParseAddr("10.128.0.3") c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} - c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} + c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true} + c.Settings["listen"] = map[string]any{"port": 4242} myVpnNet := netip.MustParsePrefix("10.128.0.1/24") nt := new(bart.Table[struct{}]) @@ -277,8 +277,8 @@ func TestLighthouse_Memory(t *testing.T) { func TestLighthouse_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} - c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} + c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true} + c.Settings["listen"] = map[string]any{"port": 4242} myVpnNet := netip.MustParsePrefix("10.128.0.1/24") nt := new(bart.Table[struct{}]) @@ -291,9 +291,9 @@ func TestLighthouse_reload(t *testing.T) { lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) require.NoError(t, err) - nc := map[interface{}]interface{}{ - "static_host_map": map[interface{}]interface{}{ - "10.128.0.2": []interface{}{"1.1.1.1:4242"}, + nc := map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"1.1.1.1:4242"}, }, } rc, err := yaml.Marshal(nc) diff --git a/main.go b/main.go index 7e94c32..b278fa6 100644 --- a/main.go +++ b/main.go @@ -13,10 +13,10 @@ import ( "github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/overlay/route.go b/overlay/route.go index 12364ec..360921f 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -72,7 +72,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return []Route{}, nil } - rawRoutes, ok := r.([]interface{}) + rawRoutes, ok := r.([]any) if !ok { return nil, fmt.Errorf("tun.routes is not an array") } @@ -83,7 +83,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { routes := make([]Route, len(rawRoutes)) for i, r := range rawRoutes { - m, ok := r.(map[interface{}]interface{}) + m, ok := r.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1) } @@ -151,7 +151,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return []Route{}, nil } - rawRoutes, ok := r.([]interface{}) + rawRoutes, ok := r.([]any) if !ok { return nil, fmt.Errorf("tun.unsafe_routes is not an array") } @@ -162,7 +162,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { routes := make([]Route, len(rawRoutes)) for i, r := range rawRoutes { - m, ok := r.(map[interface{}]interface{}) + m, ok := r.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1) } diff --git a/overlay/route_test.go b/overlay/route_test.go index eb5e914..6b5ae2e 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -24,75 +24,75 @@ func Test_parseRoutes(t *testing.T) { assert.Empty(t, routes) // not an array - c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} + c.Settings["tun"] = map[string]any{"routes": "hi"} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "tun.routes is not an array") // no routes - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} + c.Settings["tun"] = map[string]any{"routes": []any{}} routes, err = parseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // weird route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} + c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") // Not in multiple ranges - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ - map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, - map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, + c.Settings["tun"] = map[string]any{"routes": []any{ + map[string]any{"mtu": "9000", "route": "10.0.0.0/29"}, + map[string]any{"mtu": "8000", "route": "10.0.0.1/32"}, }} routes, err = parseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) @@ -129,34 +129,34 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.Empty(t, routes) // not an array - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} + c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // weird route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") // invalid via - for _, invalidValue := range []interface{}{ + for _, invalidValue := range []any{ 127, false, nil, 1.0, []string{"1", "2"}, } { - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue)) @@ -169,7 +169,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") // unparsable via - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") @@ -193,65 +193,65 @@ func Test_parseUnsafeRoutes(t *testing.T) { require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") // missing route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) require.NoError(t, err) // above network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) require.NoError(t, err) // no mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) // bad mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") // happy case - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{ + map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, + map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, + map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, + map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) @@ -288,9 +288,9 @@ func Test_makeRouteTree(t *testing.T) { n, err := netip.ParsePrefix("10.0.0.0/24") require.NoError(t, err) - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, - map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{ + map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"}, + map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index d2b2896..7f6ba4f 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -394,7 +394,7 @@ func (t *tun) addRoutes(logErrors bool) error { t.l.WithField("route", r.Cidr). Warnf("unable to add unsafe_route, identical route already exists") } else { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index bcb82b3..2a89cbc 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -271,7 +271,7 @@ func (t *tun) addRoutes(logErrors bool) error { cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 809536f..7d19c85 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -464,7 +464,7 @@ func (t *tun) addRoutes(logErrors bool) error { err := netlink.RouteReplace(&nr) if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 847f1b5..5ff9b0f 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -206,7 +206,7 @@ func (t *tun) addRoutes(logErrors bool) error { cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 03fb3a0..67a9a5f 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -175,7 +175,7 @@ func (t *tun) addRoutes(logErrors bool) error { cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 1d66eac..7aac128 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -159,7 +159,7 @@ func (t *winTun) addRoutes(logErrors bool) error { // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) continue diff --git a/punchy_test.go b/punchy_test.go index 99d703d..56dd1c2 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -27,7 +27,7 @@ func TestNewPunchyFromConfig(t *testing.T) { assert.True(t, p.GetPunch()) // punchy.punch - c.Settings["punchy"] = map[interface{}]interface{}{"punch": true} + c.Settings["punchy"] = map[string]any{"punch": true} p = NewPunchyFromConfig(l, c) assert.True(t, p.GetPunch()) @@ -37,18 +37,18 @@ func TestNewPunchyFromConfig(t *testing.T) { assert.True(t, p.GetRespond()) // punchy.respond - c.Settings["punchy"] = map[interface{}]interface{}{"respond": true} + c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false p = NewPunchyFromConfig(l, c) assert.True(t, p.GetRespond()) // punchy.delay - c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"} + c.Settings["punchy"] = map[string]any{"delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetDelay()) // punchy.respond_delay - c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"} + c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetRespondDelay()) } diff --git a/service/service_test.go b/service/service_test.go index 613758e..b9810cd 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -13,10 +13,10 @@ import ( "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) diff --git a/ssh.go b/ssh.go index 203166c..9a26c29 100644 --- a/ssh.go +++ b/ssh.go @@ -124,10 +124,10 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro } rawKeys := c.Get("sshd.authorized_users") - keys, ok := rawKeys.([]interface{}) + keys, ok := rawKeys.([]any) if ok { for _, rk := range keys { - kDef, ok := rk.(map[interface{}]interface{}) + kDef, ok := rk.(map[string]any) if !ok { l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") continue @@ -148,7 +148,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro continue } - case []interface{}: + case []any: for _, subK := range v { sk, ok := subK.(string) if !ok { @@ -190,7 +190,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") @@ -198,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListHostMap(f.hostMap, fs, w) }, }) @@ -206,7 +206,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-pending-hostmap", ShortDescription: "List all handshaking hosts", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") @@ -214,7 +214,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListHostMap(f.handshakeManager, fs, w) }, }) @@ -222,14 +222,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-lighthouse-addrmap", ShortDescription: "List all lighthouse map entries", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListLighthouseMap(f.lightHouse, fs, w) }, }) @@ -237,7 +237,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "reload", ShortDescription: "Reloads configuration from disk, same as sending HUP to the process", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshReload(c, w) }, }) @@ -251,7 +251,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "stop-cpu-profile", ShortDescription: "Stops a cpu profile and writes output to the previously provided file", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { pprof.StopCPUProfile() return w.WriteLine("If a CPU profile was running it is now stopped") }, @@ -278,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "log-level", ShortDescription: "Gets or sets the current log level", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshLogLevel(l, fs, a, w) }, }) @@ -286,7 +286,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "log-format", ShortDescription: "Gets or sets the current log format", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshLogFormat(l, fs, a, w) }, }) @@ -294,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "version", ShortDescription: "Prints the currently running version of nebula", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshVersion(f, fs, a, w) }, }) @@ -302,14 +302,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "device-info", ShortDescription: "Prints information about the network device.", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshDeviceInfoFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshDeviceInfo(f, fs, w) }, }) @@ -317,7 +317,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-cert", ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintCertFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json") @@ -325,7 +325,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintCert(f, fs, a, w) }, }) @@ -333,13 +333,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-tunnel", ShortDescription: "Prints json details about a tunnel for the provided vpn addr", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintTunnel(f, fs, a, w) }, }) @@ -347,13 +347,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-relays", ShortDescription: "Prints json details about all relay info", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintRelays(f, fs, a, w) }, }) @@ -361,13 +361,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "change-remote", ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshChangeRemoteFlags{} fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshChangeRemote(f, fs, a, w) }, }) @@ -375,13 +375,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "close-tunnel", ShortDescription: "Closes a tunnel for the provided vpn addr", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCloseTunnelFlags{} fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshCloseTunnel(f, fs, a, w) }, }) @@ -390,13 +390,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter Name: "create-tunnel", ShortDescription: "Creates a tunnel for the provided vpn address", Help: "The lighthouses will be queried for real addresses but you can provide one as well.", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCreateTunnelFlags{} fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshCreateTunnel(f, fs, a, w) }, }) @@ -405,13 +405,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter Name: "query-lighthouse", ShortDescription: "Query the lighthouses for the provided vpn address", Help: "This command is asynchronous. Only currently known udp addresses will be printed.", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshQueryLighthouse(f, fs, a, w) }, }) } -func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { +func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { return nil @@ -451,7 +451,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er return nil } -func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error { +func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { return nil @@ -505,7 +505,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr return nil } -func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { err := w.WriteLine("No path to write profile provided") return err @@ -527,11 +527,11 @@ func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error { return err } -func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("%s", ifce.version)) } -func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No vpn address was provided") } @@ -553,7 +553,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri return json.NewEncoder(w.GetWriter()).Encode(cm) } -func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCloseTunnelFlags) if !ok { return nil @@ -593,7 +593,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("Closed") } -func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCreateTunnelFlags) if !ok { return nil @@ -638,7 +638,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("Created") } -func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshChangeRemoteFlags) if !ok { return nil @@ -675,7 +675,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("Changed") } -func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } @@ -696,7 +696,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { return err } -func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error { +func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { rate := runtime.SetMutexProfileFraction(-1) return w.WriteLine(fmt.Sprintf("Current value: %d", rate)) @@ -711,7 +711,7 @@ func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) er return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate)) } -func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } @@ -735,7 +735,7 @@ func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) } -func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } @@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } -func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } @@ -767,7 +767,7 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } -func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintCertFlags) if !ok { return nil @@ -822,7 +822,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(cert.String()) } -func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) @@ -919,7 +919,7 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return nil } -func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { return nil @@ -951,7 +951,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges())) } -func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { +func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error { data := struct { Name string `json:"name"` diff --git a/sshd/command.go b/sshd/command.go index 66646a6..7323d12 100644 --- a/sshd/command.go +++ b/sshd/command.go @@ -12,7 +12,7 @@ import ( // CommandFlags is a function called before help or command execution to parse command line flags // It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags -type CommandFlags func() (*flag.FlagSet, interface{}) +type CommandFlags func() (*flag.FlagSet, any) // CommandCallback is the function called when your command should execute. // fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved @@ -21,7 +21,7 @@ type CommandFlags func() (*flag.FlagSet, interface{}) // w is the writer to use when sending messages back to the client. // If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user // where appropriate -type CommandCallback func(fs interface{}, a []string, w StringWriter) error +type CommandCallback func(fs any, a []string, w StringWriter) error type Command struct { Name string @@ -34,7 +34,7 @@ type Command struct { func execCommand(c *Command, args []string, w StringWriter) error { var ( fl *flag.FlagSet - fs interface{} + fs any ) if c.Flags != nil { @@ -85,7 +85,7 @@ func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { func matchCommand(c *radix.Tree, cmd string) []string { cmds := make([]string, 0) - c.WalkPrefix(cmd, func(found string, v interface{}) bool { + c.WalkPrefix(cmd, func(found string, v any) bool { cmds = append(cmds, found) return false }) @@ -95,7 +95,7 @@ func matchCommand(c *radix.Tree, cmd string) []string { func allCommands(c *radix.Tree) []*Command { cmds := make([]*Command, 0) - c.WalkPrefix("", func(found string, v interface{}) bool { + c.WalkPrefix("", func(found string, v any) bool { cmd, ok := v.(*Command) if ok { cmds = append(cmds, cmd) diff --git a/sshd/server.go b/sshd/server.go index c151f91..a8b60ba 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -86,7 +86,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { s.RegisterCommand(&Command{ Name: "help", ShortDescription: "prints available commands or help for specific usage info", - Callback: func(a interface{}, args []string, w StringWriter) error { + Callback: func(a any, args []string, w StringWriter) error { return helpCallback(s.commands, args, w) }, }) diff --git a/sshd/session.go b/sshd/session.go index 7c5869e..03b20cd 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -31,7 +31,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New s.commands.Insert("logout", &Command{ Name: "logout", ShortDescription: "Ends the current session", - Callback: func(a interface{}, args []string, w StringWriter) error { + Callback: func(a any, args []string, w StringWriter) error { s.Close() return nil }, diff --git a/test/assert.go b/test/assert.go index d34252e..1856877 100644 --- a/test/assert.go +++ b/test/assert.go @@ -13,7 +13,7 @@ import ( // AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory // There is currently a special case for `time.loc` (as this code traverses into unexported fields) -func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) { +func AssertDeepCopyEqual(t *testing.T, a any, b any) { v1 := reflect.ValueOf(a) v2 := reflect.ValueOf(b) diff --git a/util/error.go b/util/error.go index d7710f9..814c77a 100644 --- a/util/error.go +++ b/util/error.go @@ -9,11 +9,11 @@ import ( type ContextualError struct { RealError error - Fields map[string]interface{} + Fields map[string]any Context string } -func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError { +func NewContextualError(msg string, fields map[string]any, realError error) *ContextualError { return &ContextualError{Context: msg, Fields: fields, RealError: realError} } diff --git a/util/error_test.go b/util/error_test.go index 5041f82..692c184 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -type m map[string]interface{} +type m = map[string]any type TestLogWriter struct { Logs []string From 36bc9dd26134e31b6893b431e92a5d37e58711e0 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Tue, 1 Apr 2025 09:49:26 -0400 Subject: [PATCH 13/44] fix parseUnsafeRoutes for yaml.v3 (#1371) We switched to yaml.v3 with #1148, but missed this spot that was still casting into `map[any]any` when yaml.v3 makes it `map[string]any`. Also clean up a few more `interface{}` that were added as we changed them all to `any` with #1148. --- overlay/route.go | 4 ++-- overlay/route_test.go | 32 ++++++++++++++++---------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/overlay/route.go b/overlay/route.go index 360921f..6198958 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -215,10 +215,10 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { gateways = routing.Gateways{routing.NewGateway(viaIp, 1)} - case []interface{}: + case []any: gateways = make(routing.Gateways, len(via)) for ig, v := range via { - gatewayMap, ok := v.(map[interface{}]interface{}) + gatewayMap, ok := v.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1) } diff --git a/overlay/route_test.go b/overlay/route_test.go index 6b5ae2e..9a959a5 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -163,7 +163,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { } // Unparsable list of via - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": []string{"1", "2"}}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") @@ -175,19 +175,19 @@ func Test_parseUnsafeRoutes(t *testing.T) { require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // unparsable gateway - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "1"}}}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP") // missing gateway element - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"weight": "1"}}}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present") // unparsable weight element - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "10.0.0.1", "weight": "a"}}}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") @@ -328,34 +328,34 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) { n, err := netip.ParsePrefix("10.0.0.0/24") require.NoError(t, err) - c.Settings["tun"] = map[interface{}]interface{}{ - "unsafe_routes": []interface{}{ - map[interface{}]interface{}{ + c.Settings["tun"] = map[string]any{ + "unsafe_routes": []any{ + map[string]any{ "route": "192.168.86.0/24", "via": "192.168.100.10", }, - map[interface{}]interface{}{ + map[string]any{ "route": "192.168.87.0/24", - "via": []interface{}{ - map[interface{}]interface{}{ + "via": []any{ + map[string]any{ "gateway": "10.0.0.1", }, - map[interface{}]interface{}{ + map[string]any{ "gateway": "10.0.0.2", }, - map[interface{}]interface{}{ + map[string]any{ "gateway": "10.0.0.3", }, }, }, - map[interface{}]interface{}{ + map[string]any{ "route": "192.168.89.0/24", - "via": []interface{}{ - map[interface{}]interface{}{ + "via": []any{ + map[string]any{ "gateway": "10.0.0.1", "weight": 10, }, - map[interface{}]interface{}{ + map[string]any{ "gateway": "10.0.0.2", "weight": 5, }, From d2adebf26daed3d29ac4a2c664de999dc8c79fac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 13:24:19 -0400 Subject: [PATCH 14/44] Bump golangci/golangci-lint-action from 6 to 7 (#1361) * Bump golangci/golangci-lint-action from 6 to 7 Bumps [golangci/golangci-lint-action](https://github.com/golangci/golangci-lint-action) from 6 to 7. - [Release notes](https://github.com/golangci/golangci-lint-action/releases) - [Commits](https://github.com/golangci/golangci-lint-action/compare/v6...v7) --- updated-dependencies: - dependency-name: golangci/golangci-lint-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] * use latest golangci-lint * pin to v2.0 * golangci-lint migrate * make the tests happy --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Wade Simmons --- .github/workflows/test.yml | 8 +++---- .golangci.yaml | 26 +++++++++++++++++----- cert/crypto_test.go | 4 ++-- cmd/nebula-cert/ca_test.go | 40 +++++++++++++++++----------------- cmd/nebula-cert/keygen_test.go | 20 ++++++++--------- cmd/nebula-cert/print_test.go | 16 +++++++------- cmd/nebula-cert/verify_test.go | 32 +++++++++++++-------------- control_test.go | 2 +- firewall_test.go | 2 +- hostmap_test.go | 4 ++-- 10 files changed, 84 insertions(+), 70 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 28f0590..006115d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,9 +32,9 @@ jobs: run: make vet - name: golangci-lint - uses: golangci/golangci-lint-action@v6 + uses: golangci/golangci-lint-action@v7 with: - version: v1.64 + version: v2.0 - name: Test run: make test @@ -115,9 +115,9 @@ jobs: run: make vet - name: golangci-lint - uses: golangci/golangci-lint-action@v6 + uses: golangci/golangci-lint-action@v7 with: - version: v1.64 + version: v2.0 - name: Test run: make test diff --git a/.golangci.yaml b/.golangci.yaml index f792069..bd82a95 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,9 +1,23 @@ -# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json +version: "2" linters: - # Disable all linters. - # Default: false - disable-all: true - # Enable specific linter - # https://golangci-lint.run/usage/linters/#enabled-by-default + default: none enable: - testifylint + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + paths: + - third_party$ + - builtin$ + - examples$ +formatters: + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/cert/crypto_test.go b/cert/crypto_test.go index ee671c0..6358ba6 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -10,14 +10,14 @@ import ( func TestNewArgon2Parameters(t *testing.T) { p := NewArgon2Parameters(64*1024, 4, 3) - assert.EqualValues(t, &Argon2Parameters{ + assert.Equal(t, &Argon2Parameters{ version: argon2.Version, Memory: 64 * 1024, Parallelism: 4, Iterations: 3, }, p) p = NewArgon2Parameters(2*1024*1024, 2, 1) - assert.EqualValues(t, &Argon2Parameters{ + assert.Equal(t, &Argon2Parameters{ version: argon2.Version, Memory: 2 * 1024 * 1024, Parallelism: 2, diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 189fc02..b1cbde9 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -90,26 +90,26 @@ func Test_ca(t *testing.T) { assertHelpError(t, ca( []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, ), "-name is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // ipv4 only ips assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // ipv4 only subnets assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") @@ -121,8 +121,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") @@ -135,8 +135,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, nopw)) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) @@ -158,7 +158,7 @@ func Test_ca(t *testing.T) { assert.Empty(t, lCrt.UnsafeNetworks()) assert.Len(t, lCrt.PublicKey(), 32) assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) - assert.Equal(t, "", lCrt.Issuer()) + assert.Empty(t, lCrt.Issuer()) assert.True(t, lCrt.CheckSignature(lCrt.PublicKey())) // test encrypted key @@ -169,7 +169,7 @@ func Test_ca(t *testing.T) { args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // read encrypted key file and verify default params rb, _ = os.ReadFile(keyF.Name()) @@ -197,7 +197,7 @@ func Test_ca(t *testing.T) { args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // test when user fails to enter a password os.Remove(keyF.Name()) @@ -207,7 +207,7 @@ func Test_ca(t *testing.T) { args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) @@ -222,8 +222,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // test that we won't overwrite existing key file os.Remove(keyF.Name()) @@ -231,8 +231,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) os.Remove(keyF.Name()) } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 7eed5d2..95d9893 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -37,20 +37,20 @@ func Test_keygen(t *testing.T) { // required args assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") @@ -62,8 +62,8 @@ func Test_keygen(t *testing.T) { eb.Reset() args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp pub file pubF, err := os.CreateTemp("", "test.pub") @@ -75,8 +75,8 @@ func Test_keygen(t *testing.T) { eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} require.NoError(t, keygen(args, ob, eb)) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 061e472..221ab77 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -43,16 +43,16 @@ func Test_printCert(t *testing.T) { // no path err := printCert([]string{}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, err, "-path is required") // no cert at path ob.Reset() eb.Reset() err = printCert([]string{"-path", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) // invalid cert at path @@ -64,8 +64,8 @@ func Test_printCert(t *testing.T) { tf.WriteString("-----BEGIN NOPE-----") err = printCert([]string{"-path", tf.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") // test multiple certs @@ -155,7 +155,7 @@ func Test_printCert(t *testing.T) { `, ob.String(), ) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // test json ob.Reset() @@ -177,7 +177,7 @@ func Test_printCert(t *testing.T) { `, ob.String(), ) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) } // NewTestCaCert will generate a CA cert diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index acc9cca..f555e5f 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -38,19 +38,19 @@ func Test_verify(t *testing.T) { // required args assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // no ca at path ob.Reset() eb.Reset() err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) // invalid ca at path @@ -62,8 +62,8 @@ func Test_verify(t *testing.T) { caFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") // make a ca for later @@ -76,8 +76,8 @@ func Test_verify(t *testing.T) { // no crt at path err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path @@ -89,8 +89,8 @@ func Test_verify(t *testing.T) { certFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path @@ -106,8 +106,8 @@ func Test_verify(t *testing.T) { certFile.Write(b) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path @@ -118,7 +118,7 @@ func Test_verify(t *testing.T) { certFile.Write(b) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.NoError(t, err) } diff --git a/control_test.go b/control_test.go index de85fee..e400992 100644 --- a/control_test.go +++ b/control_test.go @@ -101,7 +101,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) - assert.EqualValues(t, &expectedInfo, thi) + assert.Equal(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet diff --git a/firewall_test.go b/firewall_test.go index c90fb20..4731a6f 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -792,7 +792,7 @@ func TestFirewall_convertRule(t *testing.T) { } r, err = convertRule(l, c, "test", 1) - assert.Equal(t, "", ob.String()) + assert.Empty(t, ob.String()) require.Error(t, err, "group should contain a single value, an array with more than one entry was provided") // Make sure a well formed group is alright diff --git a/hostmap_test.go b/hostmap_test.go index e974340..b3580cf 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -210,8 +210,8 @@ func TestHostMap_reload(t *testing.T) { assert.Empty(t, hm.GetPreferredRanges()) c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]") - assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) + assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") - assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) + assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) } From e136d1d47a630c9ac2de01949f3f3286fa110c23 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Tue, 1 Apr 2025 17:08:03 -0400 Subject: [PATCH 15/44] Update example config with default_local_cidr_any changes (#1373) --- CHANGELOG.md | 7 +++++++ examples/config.yml | 18 ++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad17147..1de3c19 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- `default_local_cidr_any` now defaults to false, meaning that any firewall rule + intended to target an `unsafe_routes` entry must explicitly declare it via the + `local_cidr` field. This is almost always the intended behavior. This flag is + deprecated and will be removed in a future release. + ## [1.9.4] - 2024-09-09 ### Added diff --git a/examples/config.yml b/examples/config.yml index 3b7c38b..534608d 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -346,11 +346,11 @@ firewall: outbound_action: drop inbound_action: drop - # Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false. - # This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an - # unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless - # of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr` - # if the intention is to allow traffic to flow to an unsafe route. + # THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.) + # This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a + # `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule + # will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr` + # is explicitly defined. This is usually not the desired behavior and should be avoided! #default_local_cidr_any: false conntrack: @@ -368,11 +368,9 @@ firewall: # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. - # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes. - # If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network. - # Otherwise the default is any vpn network assigned to via the certificate. - # `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release. - # If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation. + # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes. + # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. + # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case. # ca_name: An issuing CA name # ca_sha: An issuing CA shasum From 58ead4116ff6de08b56a7a32f930663cc9d2e9c4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:10:20 -0500 Subject: [PATCH 16/44] Bump github.com/gaissmai/bart from 0.18.1 to 0.20.1 (#1369) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index bbd5d8b..1b6be0b 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.18.1 + github.com/gaissmai/bart v0.20.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 diff --git a/go.sum b/go.sum index 8237bfa..f142a20 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ= -github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= +github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo= +github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From e4bae1582556a264f7629b4d368098db1efbf723 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:23:35 -0500 Subject: [PATCH 17/44] Bump google.golang.org/protobuf in the protobuf-dependencies group (#1365) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 1b6be0b..19e83ab 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.36.5 + google.golang.org/protobuf v1.36.6 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) diff --git a/go.sum b/go.sum index f142a20..fa8c29b 100644 --- a/go.sum +++ b/go.sum @@ -239,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= -google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From d99fd60e0622dd48b5ea67a14a251dc44efd404d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:26:23 -0500 Subject: [PATCH 18/44] Bump Apple-Actions/import-codesign-certs from 3 to 5 (#1364) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f9df115..3107b47 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -75,7 +75,7 @@ jobs: - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v3 + uses: Apple-Actions/import-codesign-certs@v5 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} From e2d6f4e444d51d46f2a4715836fdf7d116187148 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:28:27 -0500 Subject: [PATCH 19/44] Bump github.com/miekg/dns from 1.1.63 to 1.1.64 (#1363) --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 19e83ab..7302092 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.63 + github.com/miekg/dns v1.1.64 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.21.1 @@ -50,7 +50,7 @@ require ( github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect - golang.org/x/mod v0.18.0 // indirect + golang.org/x/mod v0.23.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.22.0 // indirect + golang.org/x/tools v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index fa8c29b..030d6ef 100644 --- a/go.sum +++ b/go.sum @@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY= -github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs= +github.com/miekg/dns v1.1.64 h1:wuZgD9wwCE6XMT05UU/mlSko71eRSXEAm2EbjQXLKnQ= +github.com/miekg/dns v1.1.64/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -164,8 +164,8 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= -golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -219,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= -golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From f5d096dd2b719e75b0b41bdb7a0e1fd8f86b02bb Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 2 Apr 2025 09:11:34 -0400 Subject: [PATCH 20/44] move to golang.org/x/term (#1372) The `golang.org/x/crypto/ssh/terminal` was deprecated and moved to `golang.org/x/term`. We already use the new package in `cmd/nebula-cert`, so fix our remaining reference here. See: - https://github.com/golang/go/issues/31044 --- sshd/session.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sshd/session.go b/sshd/session.go index 03b20cd..87cc216 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -9,13 +9,13 @@ import ( "github.com/armon/go-radix" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" + "golang.org/x/term" ) type session struct { l *logrus.Entry c *ssh.ServerConn - term *terminal.Terminal + term *term.Terminal commands *radix.Tree exitChan chan bool } @@ -106,8 +106,8 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { } } -func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal { - term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ") +func (s *session) createTerm(channel ssh.Channel) *term.Terminal { + term := term.NewTerminal(channel, s.c.User()+"@nebula > ") term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { // key 9 is tab if key == 9 { From e83a1c6c84a323636122b81919a56426e6ba4e7d Mon Sep 17 00:00:00 2001 From: Zeroday BYTE <47859767+odaysec@users.noreply.github.com> Date: Fri, 4 Apr 2025 01:11:20 +0700 Subject: [PATCH 21/44] Update config.go (#1353) --- config/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config.go b/config/config.go index b1531e9..5510324 100644 --- a/config/config.go +++ b/config/config.go @@ -243,7 +243,7 @@ func (c *C) GetInt(k string, d int) int { // GetUint32 will get the uint32 for k or return the default d if not found or invalid func (c *C) GetUint32(k string, d uint32) uint32 { r := c.GetInt(k, int(d)) - if uint64(r) > uint64(math.MaxUint32) { + if r < 0 || uint64(r) > uint64(math.MaxUint32) { return d } return uint32(r) From d4a7df30836008745cf27c63a29bfa713da7f11e Mon Sep 17 00:00:00 2001 From: John Maguire Date: Mon, 7 Apr 2025 18:08:29 -0400 Subject: [PATCH 22/44] Rename pki.default_version to pki.initiating_version (#1381) --- connection_manager.go | 2 +- connection_manager_test.go | 16 +++++++-------- examples/config.yml | 4 ++-- handshake_ix.go | 4 ++-- handshake_manager_test.go | 10 +++++----- interface.go | 4 ++-- lighthouse.go | 6 +++--- lighthouse_test.go | 2 +- pki.go | 40 +++++++++++++++++++------------------- 9 files changed, 44 insertions(+), 44 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index 9d8d071..5c9b3a5 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -498,7 +498,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := n.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert myCrt := cs.getCertificate(curCrt.Version()) - if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { // The current tunnel is using the latest certificate and version, no need to rehandshake. return } diff --git a/connection_manager_test.go b/connection_manager_test.go index 2c9baa1..d1c5ba3 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -44,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) { hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - defaultVersion: cert.Version1, - privateKey: []byte{}, - v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -126,10 +126,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) { hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - defaultVersion: cert.Version1, - privateKey: []byte{}, - v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() diff --git a/examples/config.yml b/examples/config.yml index 534608d..d8e7e6e 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -13,11 +13,11 @@ pki: # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. #disconnect_invalid: true - # default_version controls which certificate version is used in handshakes. + # initiating_version controls which certificate version is used when initiating handshakes. # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`. # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`. # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. - # default_version: 1 + # initiating_version: 1 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. diff --git a/handshake_ix.go b/handshake_ix.go index 0783999..571a19a 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -25,7 +25,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { // If we're connecting to a v6 address we must use a v2 cert cs := f.pki.getCertState() - v := cs.defaultVersion + v := cs.initiatingVersion for _, a := range hh.hostinfo.vpnAddrs { if a.Is6() { v = cert.Version2 @@ -101,7 +101,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if crt == nil { f.l.WithField("udpAddr", addr). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", cs.defaultVersion). + WithField("certVersion", cs.initiatingVersion). Error("Unable to handshake with host because no certificate is available") } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 4b898af..2e6d34b 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -24,10 +24,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { lh := newTestLighthouse() cs := &CertState{ - defaultVersion: cert.Version1, - privateKey: []byte{}, - v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -98,5 +98,5 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { } func (mw *mockEncWriter) GetCertState() *CertState { - return &CertState{defaultVersion: cert.Version2} + return &CertState{initiatingVersion: cert.Version2} } diff --git a/interface.go b/interface.go index 21e198c..a15e2c2 100644 --- a/interface.go +++ b/interface.go @@ -410,7 +410,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { udpStats := udp.NewUDPStatsEmitter(f.writers) certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) - certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil) + certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil) certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) for { @@ -425,7 +425,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { certState := f.pki.getCertState() defaultCrt := certState.GetDefaultCertificate() certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) - certDefaultVersion.Update(int64(defaultCrt.Version())) + certInitiatingVersion.Update(int64(defaultCrt.Version())) // Report the max certificate version we are capable of using if certState.v2Cert != nil { diff --git a/lighthouse.go b/lighthouse.go index f13afd3..eb09a39 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -763,7 +763,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { if hi != nil { v = hi.ConnectionState.myCert.Version() } else { - v = lh.ifce.GetCertState().defaultVersion + v = lh.ifce.GetCertState().initiatingVersion } if v == cert.Version1 { @@ -883,7 +883,7 @@ func (lh *LightHouse) SendUpdate() { if hi != nil { v = hi.ConnectionState.myCert.Version() } else { - v = lh.ifce.GetCertState().defaultVersion + v = lh.ifce.GetCertState().initiatingVersion } if v == cert.Version1 { if v1Update == nil { @@ -1114,7 +1114,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest) var useVersion cert.Version if targetHI == nil { - useVersion = lhh.lh.ifce.GetCertState().defaultVersion + useVersion = lhh.lh.ifce.GetCertState().initiatingVersion } else { crt := targetHI.GetCert().Certificate useVersion = crt.Version() diff --git a/lighthouse_test.go b/lighthouse_test.go index 6a541c2..c49615c 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -417,7 +417,7 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { } func (tw *testEncWriter) GetCertState() *CertState { - return &CertState{defaultVersion: tw.protocolVersion} + return &CertState{initiatingVersion: tw.protocolVersion} } // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match diff --git a/pki.go b/pki.go index 888da7c..c9f8d89 100644 --- a/pki.go +++ b/pki.go @@ -33,10 +33,10 @@ type CertState struct { v2Cert cert.Certificate v2HandshakeBytes []byte - defaultVersion cert.Version - privateKey []byte - pkcs11Backed bool - cipher string + initiatingVersion cert.Version + privateKey []byte + pkcs11Backed bool + cipher string myVpnNetworks []netip.Prefix myVpnNetworksTable *bart.Table[struct{}] @@ -194,7 +194,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { } func (cs *CertState) GetDefaultCertificate() cert.Certificate { - c := cs.getCertificate(cs.defaultVersion) + c := cs.getCertificate(cs.initiatingVersion) if c == nil { panic("No default certificate found") } @@ -317,28 +317,28 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, errors.New("no certificates found in pki.cert") } - useDefaultVersion := uint32(1) + useInitiatingVersion := uint32(1) if v1 == nil { // The only condition that requires v2 as the default is if only a v2 certificate is present // We do this to avoid having to configure it specifically in the config file - useDefaultVersion = 2 + useInitiatingVersion = 2 } - rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion) - var defaultVersion cert.Version - switch rawDefaultVersion { + rawInitiatingVersion := c.GetUint32("pki.initiating_version", useInitiatingVersion) + var initiatingVersion cert.Version + switch rawInitiatingVersion { case 1: if v1 == nil { - return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert") + return nil, fmt.Errorf("can not use pki.initiating_version 1 without a v1 certificate in pki.cert") } - defaultVersion = cert.Version1 + initiatingVersion = cert.Version1 case 2: - defaultVersion = cert.Version2 + initiatingVersion = cert.Version2 default: - return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion) + return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) } - return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey) + return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) } func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { @@ -361,7 +361,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p //TODO: CERT-V2 make sure v2 has v1s address - cs.defaultVersion = dv + cs.initiatingVersion = dv } if v1 != nil { @@ -380,8 +380,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p cs.v1Cert = v1 cs.v1HandshakeBytes = v1hs - if cs.defaultVersion == 0 { - cs.defaultVersion = cert.Version1 + if cs.initiatingVersion == 0 { + cs.initiatingVersion = cert.Version1 } } @@ -401,8 +401,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p cs.v2Cert = v2 cs.v2HandshakeBytes = v2hs - if cs.defaultVersion == 0 { - cs.defaultVersion = cert.Version2 + if cs.initiatingVersion == 0 { + cs.initiatingVersion = cert.Version2 } } From c7fb3ad9cfc2b1d7249c965aace0035e83753cc6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 11:39:31 -0400 Subject: [PATCH 23/44] Bump the golang-x-dependencies group with 4 updates (#1382) Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/sync](https://github.com/golang/sync), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term). Updates `golang.org/x/crypto` from 0.36.0 to 0.37.0 - [Commits](https://github.com/golang/crypto/compare/v0.36.0...v0.37.0) Updates `golang.org/x/sync` from 0.12.0 to 0.13.0 - [Commits](https://github.com/golang/sync/compare/v0.12.0...v0.13.0) Updates `golang.org/x/sys` from 0.31.0 to 0.32.0 - [Commits](https://github.com/golang/sys/compare/v0.31.0...v0.32.0) Updates `golang.org/x/term` from 0.30.0 to 0.31.0 - [Commits](https://github.com/golang/term/compare/v0.30.0...v0.31.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-version: 0.37.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sync dependency-version: 0.13.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-version: 0.32.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-version: 0.31.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 7302092..5db8b67 100644 --- a/go.mod +++ b/go.mod @@ -24,12 +24,12 @@ require ( github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.10.0 github.com/vishvananda/netlink v1.3.0 - golang.org/x/crypto v0.36.0 + golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/net v0.38.0 - golang.org/x/sync v0.12.0 - golang.org/x/sys v0.31.0 - golang.org/x/term v0.30.0 + golang.org/x/sync v0.13.0 + golang.org/x/sys v0.32.0 + golang.org/x/term v0.31.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index 030d6ef..f258360 100644 --- a/go.sum +++ b/go.sum @@ -156,8 +156,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -185,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -204,11 +204,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 18279ed17b10f75b109511f4d7af99854920e137 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 11:40:34 -0400 Subject: [PATCH 24/44] Bump github.com/miekg/dns from 1.1.64 to 1.1.65 (#1384) Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.64 to 1.1.65. - [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release) - [Commits](https://github.com/miekg/dns/compare/v1.1.64...v1.1.65) --- updated-dependencies: - dependency-name: github.com/miekg/dns dependency-version: 1.1.65 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 5db8b67..62a57b3 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.64 + github.com/miekg/dns v1.1.65 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.21.1 diff --git a/go.sum b/go.sum index f258360..26357d7 100644 --- a/go.sum +++ b/go.sum @@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.64 h1:wuZgD9wwCE6XMT05UU/mlSko71eRSXEAm2EbjQXLKnQ= -github.com/miekg/dns v1.1.64/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= +github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc= +github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= From 459cb38a6d463f4b4800f6844e139eb7b20ea31d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 11:46:46 -0400 Subject: [PATCH 25/44] Bump github.com/gaissmai/bart from 0.20.1 to 0.20.4 (#1391) * Bump github.com/gaissmai/bart from 0.20.1 to 0.20.4 Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.20.1 to 0.20.4. - [Release notes](https://github.com/gaissmai/bart/releases) - [Commits](https://github.com/gaissmai/bart/compare/v0.20.1...v0.20.4) --- updated-dependencies: - dependency-name: github.com/gaissmai/bart dependency-version: 0.20.4 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * set back to go 1.23.0 We were only on 1.23.6 because of bart in the first place. --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Wade Simmons --- go.mod | 4 ++-- go.sum | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 62a57b3..9e10ad6 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/slackhq/nebula -go 1.23.6 +go 1.23.0 toolchain go1.24.1 @@ -10,7 +10,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.20.1 + github.com/gaissmai/bart v0.20.4 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 diff --git a/go.sum b/go.sum index 26357d7..2a08b64 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo= -github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= +github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U= +github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From e49f2790041d80af54ec9ba3783d7239fb662977 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 17 Apr 2025 06:41:53 -0400 Subject: [PATCH 26/44] Bump golang.org/x/net in the golang-x-dependencies group (#1392) Bumps the golang-x-dependencies group with 1 update: [golang.org/x/net](https://github.com/golang/net). Updates `golang.org/x/net` from 0.38.0 to 0.39.0 - [Commits](https://github.com/golang/net/compare/v0.38.0...v0.39.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-version: 0.39.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 9e10ad6..b5e371c 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/vishvananda/netlink v1.3.0 golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.38.0 + golang.org/x/net v0.39.0 golang.org/x/sync v0.13.0 golang.org/x/sys v0.32.0 golang.org/x/term v0.31.0 diff --git a/go.sum b/go.sum index 2a08b64..e644a29 100644 --- a/go.sum +++ b/go.sum @@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= From 4eb056af9dd58ce6e911c28011a0a7b8508e8b93 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 17 Apr 2025 06:43:55 -0400 Subject: [PATCH 27/44] Bump github.com/prometheus/client_golang from 1.21.1 to 1.22.0 (#1393) Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.21.1 to 1.22.0. - [Release notes](https://github.com/prometheus/client_golang/releases) - [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md) - [Commits](https://github.com/prometheus/client_golang/compare/v1.21.1...v1.22.0) --- updated-dependencies: - dependency-name: github.com/prometheus/client_golang dependency-version: 1.22.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 3 +-- go.sum | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index b5e371c..d90a937 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/miekg/dns v1.1.65 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.21.1 + github.com/prometheus/client_golang v1.22.0 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e @@ -43,7 +43,6 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect - github.com/klauspost/compress v1.17.11 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect diff --git a/go.sum b/go.sum index e644a29..920ee57 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= @@ -68,8 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= -github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -106,8 +106,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= -github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= From b8ea55eb90bfaee6966499d16168848d7cc7e4a2 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Fri, 18 Apr 2025 12:37:20 -0400 Subject: [PATCH 28/44] optimize usage of bart (#1395) Use `bart.Lite` and `.Contains` as suggested by the bart maintainer: - https://github.com/gaissmai/bart/commit/9455952eedcf59a6e755fc28ed16e906fa4f3066#commitcomment-155362580 --- control.go | 3 +-- dns_server.go | 6 +++--- firewall.go | 25 +++++++++++-------------- handshake_ix.go | 7 +++---- handshake_manager.go | 3 +-- hostmap.go | 8 ++++---- inside.go | 9 +++------ interface.go | 10 +++++----- lighthouse.go | 23 ++++++++--------------- lighthouse_test.go | 20 ++++++++++---------- outside.go | 3 +-- pki.go | 18 +++++++++--------- relay_manager.go | 6 ++---- 13 files changed, 61 insertions(+), 80 deletions(-) diff --git a/control.go b/control.go index 20dd7fe..016a79b 100644 --- a/control.go +++ b/control.go @@ -131,8 +131,7 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { - _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) - if found { + if c.f.myVpnAddrsTable.Contains(vpnIp) { // Only returning the default certificate since its impossible // for any other host but ourselves to have more than 1 return c.f.pki.getCertState().GetDefaultCertificate().Copy() diff --git a/dns_server.go b/dns_server.go index 710f6ed..7357654 100644 --- a/dns_server.go +++ b/dns_server.go @@ -26,7 +26,7 @@ type dnsRecords struct { dnsMap4 map[string]netip.Addr dnsMap6 map[string]netip.Addr hostMap *HostMap - myVpnAddrsTable *bart.Table[struct{}] + myVpnAddrsTable *bart.Lite } func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { @@ -112,8 +112,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { return true } - _, found := d.myVpnAddrsTable.Lookup(b) - return found //if we found it in this table, it's good + //if we found it in this table, it's good + return d.myVpnAddrsTable.Contains(b) } func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { diff --git a/firewall.go b/firewall.go index e730114..971c156 100644 --- a/firewall.go +++ b/firewall.go @@ -53,7 +53,7 @@ type Firewall struct { // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. // The vpn addresses are a full bit match while the unsafe networks only match the prefix - routableNetworks *bart.Table[struct{}] + routableNetworks *bart.Lite // assignedNetworks is a list of vpn networks assigned to us in the certificate. assignedNetworks []netip.Prefix @@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { Any bool - LocalCIDR *bart.Table[struct{}] + LocalCIDR *bart.Lite } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. @@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D tmax = defaultTimeout } - routableNetworks := new(bart.Table[struct{}]) + routableNetworks := new(bart.Lite) var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - routableNetworks.Insert(nprefix, struct{}{}) + routableNetworks.Insert(nprefix) assignedNetworks = append(assignedNetworks, network) } hasUnsafeNetworks := false for _, n := range c.UnsafeNetworks() { - routableNetworks.Insert(n, struct{}{}) + routableNetworks.Insert(n) hasUnsafeNetworks = true } @@ -431,8 +431,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * // Make sure remote address matches nebula certificate if h.networks != nil { - _, ok := h.networks.Lookup(fp.RemoteAddr) - if !ok { + if !h.networks.Contains(fp.RemoteAddr) { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } @@ -445,8 +444,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to be handling this local ip address - _, ok := f.routableNetworks.Lookup(fp.LocalAddr) - if !ok { + if !f.routableNetworks.Contains(fp.LocalAddr) { f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } @@ -752,7 +750,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ - LocalCIDR: new(bart.Table[struct{}]), + LocalCIDR: new(bart.Lite), } } @@ -879,7 +877,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { } for _, network := range f.assignedNetworks { - flc.LocalCIDR.Insert(network, struct{}{}) + flc.LocalCIDR.Insert(network) } return nil @@ -888,7 +886,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { return nil } - flc.LocalCIDR.Insert(localIp, struct{}{}) + flc.LocalCIDR.Insert(localIp) return nil } @@ -901,8 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate return true } - _, ok := flc.LocalCIDR.Lookup(p.LocalAddr) - return ok + return flc.LocalCIDR.Contains(p.LocalAddr) } type rule struct { diff --git a/handshake_ix.go b/handshake_ix.go index 571a19a..cf422b9 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -192,8 +192,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet for _, network := range remoteCert.Certificate.Networks() { vpnAddr := network.Addr() - _, found := f.myVpnAddrsTable.Lookup(vpnAddr) - if found { + if f.myVpnAddrsTable.Contains(vpnAddr) { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). WithField("certName", certName). WithField("certVersion", certVersion). @@ -204,7 +203,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } // vpnAddrs outside our vpn networks are of no use to us, filter them out - if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + if !f.myVpnNetworksTable.Contains(vpnAddr) { continue } @@ -579,7 +578,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha for _, network := range vpnNetworks { // vpnAddrs outside our vpn networks are of no use to us, filter them out vpnAddr := network.Addr() - if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + if !f.myVpnNetworksTable.Contains(vpnAddr) { continue } diff --git a/handshake_manager.go b/handshake_manager.go index 6f95402..486541b 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -274,8 +274,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } // Don't relay through the host I'm trying to connect to - _, found := hm.f.myVpnAddrsTable.Lookup(relay) - if found { + if hm.f.myVpnAddrsTable.Contains(relay) { continue } diff --git a/hostmap.go b/hostmap.go index f9e3c4e..359749b 100644 --- a/hostmap.go +++ b/hostmap.go @@ -223,7 +223,7 @@ type HostInfo struct { recvError atomic.Uint32 // networks are both all vpn and unsafe networks assigned to this host - networks *bart.Table[struct{}] + networks *bart.Lite relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -732,13 +732,13 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { return } - i.networks = new(bart.Table[struct{}]) + i.networks = new(bart.Lite) for _, network := range networks { - i.networks.Insert(network, struct{}{}) + i.networks.Insert(network) } for _, network := range unsafeNetworks { - i.networks.Insert(network, struct{}{}) + i.networks.Insert(network) } } diff --git a/inside.go b/inside.go index 0af350d..239ea6a 100644 --- a/inside.go +++ b/inside.go @@ -22,14 +22,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // Ignore local broadcast packets if f.dropLocalBroadcast { - _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr) - if found { + if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { return } } - _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr) - if found { + if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula addr to the Nebula addr through the Nebula @@ -130,8 +128,7 @@ func (f *Interface) Handshake(vpnAddr netip.Addr) { // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - _, found := f.myVpnNetworksTable.Lookup(vpnAddr) - if found { + if f.myVpnNetworksTable.Contains(vpnAddr) { return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } diff --git a/interface.go b/interface.go index a15e2c2..ddd0681 100644 --- a/interface.go +++ b/interface.go @@ -61,11 +61,11 @@ type Interface struct { serveDns bool createTime time.Time lightHouse *LightHouse - myBroadcastAddrsTable *bart.Table[struct{}] - myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate - myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate - myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate - myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate + myBroadcastAddrsTable *bart.Lite + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Lite + myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate + myVpnNetworksTable *bart.Lite dropLocalBroadcast bool dropMulticast bool routines int diff --git a/lighthouse.go b/lighthouse.go index eb09a39..7a679c7 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -32,7 +32,7 @@ type LightHouse struct { amLighthouse bool myVpnNetworks []netip.Prefix - myVpnNetworksTable *bart.Table[struct{}] + myVpnNetworksTable *bart.Lite punchConn udp.Conn punchy *Punchy @@ -201,8 +201,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used addr := addrs[0].Unmap() - _, found := lh.myVpnNetworksTable.Lookup(addr) - if found { + if lh.myVpnNetworksTable.Contains(addr) { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue @@ -359,8 +358,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - _, found := lh.myVpnNetworksTable.Lookup(addr) - if !found { + if !lh.myVpnNetworksTable.Contains(addr) { return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) } lhMap[addr] = struct{}{} @@ -431,8 +429,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - _, found := lh.myVpnNetworksTable.Lookup(vpnAddr) - if !found { + if !lh.myVpnNetworksTable.Contains(vpnAddr) { return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } @@ -653,8 +650,7 @@ func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { return false } - _, found := lh.myVpnNetworksTable.Lookup(to) - if found { + if lh.myVpnNetworksTable.Contains(to) { return false } @@ -674,8 +670,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo return false } - _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) - if found { + if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) { return false } @@ -695,8 +690,7 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo return false } - _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) - if found { + if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) { return false } @@ -856,8 +850,7 @@ func (lh *LightHouse) SendUpdate() { lal := lh.GetLocalAllowList() for _, e := range localAddrs(lh.l, lal) { - _, found := lh.myVpnNetworksTable.Lookup(e) - if found { + if lh.myVpnNetworksTable.Contains(e) { continue } diff --git a/lighthouse_test.go b/lighthouse_test.go index c49615c..eb2d26e 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -31,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) { func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, @@ -56,8 +56,8 @@ func Test_lhStaticMapping(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, @@ -91,8 +91,8 @@ func TestReloadLighthouseInterval(t *testing.T) { func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/0") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, @@ -196,8 +196,8 @@ func TestLighthouse_Memory(t *testing.T) { c.Settings["listen"] = map[string]any{"port": 4242} myVpnNet := netip.MustParsePrefix("10.128.0.1/24") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, @@ -281,8 +281,8 @@ func TestLighthouse_reload(t *testing.T) { c.Settings["listen"] = map[string]any{"port": 4242} myVpnNet := netip.MustParsePrefix("10.128.0.1/24") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, diff --git a/outside.go b/outside.go index 1e9cde1..3a7b3a7 100644 --- a/outside.go +++ b/outside.go @@ -31,8 +31,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] //l.Error("in packet ", header, packet[HeaderLen:]) if ip.IsValid() { - _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) - if found { + if f.myVpnNetworksTable.Contains(ip.Addr()) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } diff --git a/pki.go b/pki.go index c9f8d89..9cab491 100644 --- a/pki.go +++ b/pki.go @@ -39,10 +39,10 @@ type CertState struct { cipher string myVpnNetworks []netip.Prefix - myVpnNetworksTable *bart.Table[struct{}] + myVpnNetworksTable *bart.Lite myVpnAddrs []netip.Addr - myVpnAddrsTable *bart.Table[struct{}] - myVpnBroadcastAddrsTable *bart.Table[struct{}] + myVpnAddrsTable *bart.Lite + myVpnBroadcastAddrsTable *bart.Lite } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -345,9 +345,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, - myVpnNetworksTable: new(bart.Table[struct{}]), - myVpnAddrsTable: new(bart.Table[struct{}]), - myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), + myVpnNetworksTable: new(bart.Lite), + myVpnAddrsTable: new(bart.Lite), + myVpnBroadcastAddrsTable: new(bart.Lite), } if v1 != nil && v2 != nil { @@ -415,16 +415,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p for _, network := range crt.Networks() { cs.myVpnNetworks = append(cs.myVpnNetworks, network) - cs.myVpnNetworksTable.Insert(network, struct{}{}) + cs.myVpnNetworksTable.Insert(network) cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) - cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) + cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen())) if network.Addr().Is4() { addr := network.Masked().Addr().As4() mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) - cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) + cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen())) } } diff --git a/relay_manager.go b/relay_manager.go index 7565350..5dd355c 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -241,15 +241,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - _, found := f.myVpnAddrsTable.Lookup(from) - if found { + if f.myVpnAddrsTable.Contains(from) { logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } // Is the target of the relay me? - _, found = f.myVpnAddrsTable.Lookup(target) - if found { + if f.myVpnAddrsTable.Contains(target) { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { From 2dc30fc300c0451e92bdfe463057594f8af66766 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Mon, 21 Apr 2025 13:28:43 -0400 Subject: [PATCH 29/44] Support 32-bit machines in crypto test (#1394) --- cert/crypto_test.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cert/crypto_test.go b/cert/crypto_test.go index 6358ba6..174b241 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -26,21 +26,21 @@ func TestNewArgon2Parameters(t *testing.T) { } func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) { - passphrase := []byte("DO NOT USE THIS KEY") + passphrase := []byte("DO NOT USE") privKey := []byte(`# A good key -----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT -oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl -+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB -qrlJ69wer3ZUHFXA +CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiCPoDfGQiosxNPTbPn5EsMlc2MI +c0Bt4oz6gTrFQhX3aBJcimhHKeAuhyTGvllD0Z19fe+DFPcLH3h5VrdjVfIAajg0 +KrbV3n9UHif/Au5skWmquNJzoW1E4MTdRbvpti6o+WdQ49DxjBFhx0YH8LBqrbPU +0BGkUHmIO7daP24= -----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- `) shortKey := []byte(`# A key which, once decrypted, is too short -----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 -k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe -GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs -rQr3bdH3Oy/WiYU= +CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiAVJwdfl3r+eqi/vF6S7OMdpjfo +hAzmTCRnr58Su4AqmBJbCv3zleYCEKYJP6UI3S8ekLMGISsgO4hm5leukCCyqT0Z +cQ76yrberpzkJKoPLGisX8f+xdy4aXSZl7oEYWQte1+vqbtl/eY9PGZhxUQdcyq7 +hqzIyrRqfUgVuA== -----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- `) invalidBanner := []byte(`# Invalid banner (not encrypted) From e5ce8966d6730affd57a3cf4937a14ba88a855f1 Mon Sep 17 00:00:00 2001 From: Andriyanov Nikita Date: Mon, 21 Apr 2025 20:44:33 +0300 Subject: [PATCH 30/44] add netlink options (#1326) * add netlink options * force use buffer * fix namings and add config examples * fix linter --- examples/config.yml | 4 ++++ overlay/tun_linux.go | 40 +++++++++++++++++++++++++++------------- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index d8e7e6e..eec4f1c 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -275,6 +275,10 @@ tun: # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of # in nebula configuration files. Default false, not reloadable. #use_system_route_table: false + # Buffer size for reading routes updates. 0 means default system buffer size. (/proc/sys/net/core/rmem_default). + # If using massive routes updates, for example BGP, you may need to increase this value to avoid packet loss. + # SO_RCVBUFFORCE is used to avoid having to raise the system wide max + #use_system_route_table_buffer_size: 0 # Configure logging level logging: diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7d19c85..4c509ba 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -34,10 +34,11 @@ type tun struct { deviceIndex int ioctlFd uintptr - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - routeChan chan struct{} - useSystemRoutes bool + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + routeChan chan struct{} + useSystemRoutes bool + useSystemRoutesBufferSize int l *logrus.Logger } @@ -124,12 +125,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), - vpnNetworks: vpnNetworks, - TXQueueLen: c.GetInt("tun.tx_queue", 500), - useSystemRoutes: c.GetBool("tun.use_system_route_table", false), - l: l, + ReadWriteCloser: file, + fd: int(file.Fd()), + vpnNetworks: vpnNetworks, + TXQueueLen: c.GetInt("tun.tx_queue", 500), + useSystemRoutes: c.GetBool("tun.use_system_route_table", false), + useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), + l: l, } err := t.reload(c, true) @@ -531,7 +533,13 @@ func (t *tun) watchRoutes() { rch := make(chan netlink.RouteUpdate) doneChan := make(chan struct{}) - if err := netlink.RouteSubscribe(rch, doneChan); err != nil { + netlinkOptions := netlink.RouteSubscribeOptions{ + ReceiveBufferSize: t.useSystemRoutesBufferSize, + ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0, + ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") }, + } + + if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil { t.l.WithError(err).Errorf("failed to subscribe to system route changes") return } @@ -541,8 +549,14 @@ func (t *tun) watchRoutes() { go func() { for { select { - case r := <-rch: - t.updateRoutes(r) + case r, ok := <-rch: + if ok { + t.updateRoutes(r) + } else { + // may be should do something here as + // netlink stops sending updates + return + } case <-doneChan: // netlink.RouteSubscriber will close the rch for us return From 15b5a4330034e7f93b449c1253d2f8f4691de726 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Mon, 21 Apr 2025 13:45:48 -0400 Subject: [PATCH 31/44] Update issue and PR templates (#1376) --- .github/ISSUE_TEMPLATE/config.yml | 20 ++++++++++++++------ .github/pull_request_template.md | 11 +++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) create mode 100644 .github/pull_request_template.md diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 94e2c6b..07e1580 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,13 +1,21 @@ blank_issues_enabled: true contact_links: + - name: 💨 Performance Issues + url: https://github.com/slackhq/nebula/discussions/new/choose + about: 'We ask that you create a discussion instead of an issue for performance-related questions. This allows us to have a more open conversation about the issue and helps us to better understand the problem.' + + - name: 📄 Documentation Issues + url: https://github.com/definednet/nebula-docs + about: "If you've found an issue with the website documentation, please file it in the nebula-docs repository." + + - name: 📱 Mobile Nebula Issues + url: https://github.com/definednet/mobile_nebula + about: "If you're using the mobile Nebula app and have found an issue, please file it in the mobile_nebula repository." + - name: 📘 Documentation url: https://nebula.defined.net/docs/ - about: Review documentation. + about: 'The documentation is the best place to start if you are new to Nebula.' - name: 💁 Support/Chat url: https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU - about: 'This issue tracker is not for support questions. Join us on Slack for assistance!' - - - name: 📱 Mobile Nebula - url: https://github.com/definednet/mobile_nebula - about: 'This issue tracker is not for mobile support. Try the Mobile Nebula repo instead!' + about: 'For faster support, join us on Slack for assistance!' diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..102ddb3 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,11 @@ + From 8536c5764565fcb0381c37772811b92c001d1b89 Mon Sep 17 00:00:00 2001 From: maggie44 <64841595+maggie44@users.noreply.github.com> Date: Mon, 21 Apr 2025 18:45:59 +0100 Subject: [PATCH 32/44] Allow configuration of logger and build version in gvisor service library (#1239) --- examples/go_service/main.go | 15 ++++++++++++++- service/service.go | 12 +----------- service/service_test.go | 14 +++++++++++++- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/examples/go_service/main.go b/examples/go_service/main.go index 30178c0..2f8efbf 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -5,8 +5,12 @@ import ( "fmt" "log" "net" + "os" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/service" ) @@ -59,7 +63,16 @@ pki: if err := cfg.LoadString(configStr); err != nil { return err } - svc, err := service.New(&cfg) + + logger := logrus.New() + logger.Out = os.Stdout + + ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) + if err != nil { + return err + } + + svc, err := service.New(ctrl) if err != nil { return err } diff --git a/service/service.go b/service/service.go index 4339677..fc8ac97 100644 --- a/service/service.go +++ b/service/service.go @@ -9,13 +9,10 @@ import ( "math" "net" "net/netip" - "os" "strings" "sync" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" - "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "golang.org/x/sync/errgroup" "gvisor.dev/gvisor/pkg/buffer" @@ -46,14 +43,7 @@ type Service struct { } } -func New(config *config.C) (*Service, error) { - logger := logrus.New() - logger.Out = os.Stdout - - control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) - if err != nil { - return nil, err - } +func New(control *nebula.Control) (*Service, error) { control.Start() ctx := control.Context() diff --git a/service/service_test.go b/service/service_test.go index b9810cd..f1c91a7 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -5,13 +5,17 @@ import ( "context" "errors" "net/netip" + "os" "testing" "time" "dario.cat/mergo" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v3" ) @@ -71,7 +75,15 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n panic(err) } - s, err := New(&c) + logger := logrus.New() + logger.Out = os.Stdout + + control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) + if err != nil { + panic(err) + } + + s, err := New(control) if err != nil { panic(err) } From 83ff2461e29ba466ef7121f1b1fb1ac070c40e84 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Mon, 28 Apr 2025 13:36:06 -0400 Subject: [PATCH 33/44] Mention CA expiration in the README (#1378) --- README.md | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 5eea0e2..0284087 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ It lets you seamlessly connect computers anywhere in the world. Nebula is portab It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers. Nebula incorporates a number of existing concepts like encryption, security groups, certificates, -and tunneling, and each of those individual pieces existed before Nebula in various forms. +and tunneling. What makes Nebula different to existing offerings is that it brings all of these ideas together, resulting in a sum that is greater than its individual parts. @@ -64,10 +64,10 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for ## Technical Overview -Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/). +Nebula is a mutually authenticated peer-to-peer software-defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/). Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups. Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes. -Discovery nodes allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs. +Discovery nodes (aka lighthouses) allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs. Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme. Nebula uses Elliptic-curve Diffie-Hellman (`ECDH`) key exchange and `AES-256-GCM` in its default configuration. @@ -82,19 +82,22 @@ To set up a Nebula network, you'll need: #### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse. -Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $5/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses. - - Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet. +Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $6/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses. +Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet. #### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network. - ``` - ./nebula-cert ca -name "Myorganization, Inc" - ``` - This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption. +``` +./nebula-cert ca -name "Myorganization, Inc" +``` + +This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption. + +**Be aware!** By default, certificate authorities have a 1-year lifetime before expiration. See [this guide](https://nebula.defined.net/docs/guides/rotating-certificate-authority/) for details on rotating a CA. #### 4. Nebula host keys and certificates generated from that certificate authority + This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network. ``` ./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24" @@ -103,7 +106,10 @@ This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You ./nebula-cert sign -name "host3" -ip "192.168.100.10/24" ``` +By default, host certificates will expire 1 second before the CA expires. Use the `-duration` flag to specify a shorter lifetime. + #### 5. Configuration files for each host + Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yml). * On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set. @@ -118,10 +124,13 @@ For each host, copy the nebula binary to the host, along with `config.yml` from **DO NOT COPY `ca.key` TO INDIVIDUAL NODES.** #### 7. Run nebula on each host + ``` ./nebula -config /path/to/config.yml ``` +For more detailed instructions, [find the full documentation here](https://nebula.defined.net/docs/). + ## Building Nebula from source Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory. @@ -140,8 +149,10 @@ The default curve used for cryptographic handshakes and signatures is Curve25519 In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets: - make bin-boringcrypto - make release-boringcrypto +``` +make bin-boringcrypto +make release-boringcrypto +``` This is not the recommended default deployment, but may be useful based on your compliance requirements. @@ -149,5 +160,3 @@ This is not the recommended default deployment, but may be useful based on your Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang. - - From 92a924808364e95383714bacc006dec777c6bd72 Mon Sep 17 00:00:00 2001 From: Andy George Date: Fri, 2 May 2025 15:32:00 -0500 Subject: [PATCH 34/44] Minor fixes to Readme shell snippets (#1389) --- README.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 0284087..3208239 100644 --- a/README.md +++ b/README.md @@ -28,33 +28,33 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for #### Distribution Packages - [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/) - ``` - $ sudo pacman -S nebula + ```sh + sudo pacman -S nebula ``` - [Fedora Linux](https://src.fedoraproject.org/rpms/nebula) - ``` - $ sudo dnf install nebula + ```sh + sudo dnf install nebula ``` - [Debian Linux](https://packages.debian.org/source/stable/nebula) - ``` - $ sudo apt install nebula + ```sh + sudo apt install nebula ``` - [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula) - ``` - $ sudo apk add nebula + ```sh + sudo apk add nebula ``` - [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb) - ``` - $ brew install nebula + ```sh + brew install nebula ``` - [Docker](https://hub.docker.com/r/nebulaoss/nebula) - ``` - $ docker pull nebulaoss/nebula + ```sh + docker pull nebulaoss/nebula ``` #### Mobile @@ -88,7 +88,7 @@ Once you have launched an instance, ensure that Nebula udp traffic (default port #### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network. -``` +```sh ./nebula-cert ca -name "Myorganization, Inc" ``` @@ -99,7 +99,7 @@ This will create files named `ca.key` and `ca.cert` in the current directory. Th #### 4. Nebula host keys and certificates generated from that certificate authority This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network. -``` +```sh ./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24" ./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh" ./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers" @@ -125,7 +125,7 @@ For each host, copy the nebula binary to the host, along with `config.yml` from #### 7. Run nebula on each host -``` +```sh ./nebula -config /path/to/config.yml ``` @@ -149,7 +149,7 @@ The default curve used for cryptographic handshakes and signatures is Curve25519 In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets: -``` +```sh make bin-boringcrypto make release-boringcrypto ``` From 061e733007ceae26decfefcd199b2a19c7afeb7e Mon Sep 17 00:00:00 2001 From: Ian VanSchooten Date: Tue, 13 May 2025 12:00:22 -0400 Subject: [PATCH 35/44] Fix slack invitation link in issue template (#1406) --- .github/ISSUE_TEMPLATE/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 07e1580..fe7dbcd 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -17,5 +17,5 @@ contact_links: about: 'The documentation is the best place to start if you are new to Nebula.' - name: 💁 Support/Chat - url: https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU + url: https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ about: 'For faster support, join us on Slack for assistance!' From 442a52879b6b19f5b455d96a3b2a3f1e3fe57649 Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:15:15 -0400 Subject: [PATCH 36/44] Fix off by one error in IPv6 packet parser (#1419) --- outside.go | 9 ++++----- outside_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/outside.go b/outside.go index 3a7b3a7..6d4127d 100644 --- a/outside.go +++ b/outside.go @@ -312,12 +312,11 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { offset := ipv6.HeaderLen // Start at the end of the ipv6 header next := 0 for { - if dataLen < offset { + if protoAt >= dataLen { break } - proto := layers.IPProtocol(data[protoAt]) - //fmt.Println(proto, protoAt) + switch proto { case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) @@ -365,7 +364,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { case layers.IPProtocolAH: // Auth headers, used by IPSec, have a different meaning for header length - if dataLen < offset+1 { + if dataLen <= offset+1 { break } @@ -373,7 +372,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { default: // Normal ipv6 header length processing - if dataLen < offset+1 { + if dataLen <= offset+1 { break } diff --git a/outside_test.go b/outside_test.go index c63e57d..38dbef6 100644 --- a/outside_test.go +++ b/outside_test.go @@ -117,6 +117,45 @@ func Test_newPacket_v6(t *testing.T) { err = newPacket(buffer.Bytes(), true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + // A v6 packet with a hop-by-hop extension + // ICMPv6 Payload (Echo Request) + icmpLayer := layers.ICMPv6{ + TypeCode: layers.ICMPv6TypeEchoRequest, + } + // Hop-by-Hop Extension Header + hopOption := layers.IPv6HopByHopOption{} + hopOption.OptionData = []byte{0, 0, 0, 0} + hopByHop := layers.IPv6HopByHop{} + hopByHop.Options = append(hopByHop.Options, &hopOption) + + ip = layers.IPv6{ + Version: 6, + HopLimit: 128, + NextHeader: layers.IPProtocolIPv6Destination, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: true, + }, &ip, &hopByHop, &icmpLayer) + if err != nil { + panic(err) + } + // Ensure buffer length checks during parsing with the next 2 tests. + + // A full IPv6 header and 1 byte in the first extension, but missing + // the length byte. + err = newPacket(buffer.Bytes()[:41], true, p) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A full IPv6 header plus 1 full extension, but only 1 byte of the + // next layer, missing length byte + err = newPacket(buffer.Bytes()[:49], true, p) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + // A good ICMP packet ip = layers.IPv6{ Version: 6, @@ -288,6 +327,10 @@ func Test_newPacket_v6(t *testing.T) { assert.Equal(t, uint16(22), p.LocalPort) assert.False(t, p.Fragment) + // Ensure buffer bounds checking during processing + err = newPacket(b[:41], true, p) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) + // Invalid AH header b = buffer.Bytes() err = newPacket(b, true, p) From d34c2b8e066ed5f356cb2668a77ed67ec26b2783 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Jul 2025 23:25:24 -0500 Subject: [PATCH 37/44] Bump golangci/golangci-lint-action from 7 to 8 (#1400) * Bump golangci/golangci-lint-action from 7 to 8 Bumps [golangci/golangci-lint-action](https://github.com/golangci/golangci-lint-action) from 7 to 8. - [Release notes](https://github.com/golangci/golangci-lint-action/releases) - [Commits](https://github.com/golangci/golangci-lint-action/compare/v7...v8) --- updated-dependencies: - dependency-name: golangci/golangci-lint-action dependency-version: '8' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] * bump golangci-lint version --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Wade Simmons --- .github/workflows/test.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 006115d..00b3936 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,9 +32,9 @@ jobs: run: make vet - name: golangci-lint - uses: golangci/golangci-lint-action@v7 + uses: golangci/golangci-lint-action@v8 with: - version: v2.0 + version: v2.1 - name: Test run: make test @@ -115,9 +115,9 @@ jobs: run: make vet - name: golangci-lint - uses: golangci/golangci-lint-action@v7 + uses: golangci/golangci-lint-action@v8 with: - version: v2.0 + version: v2.1 - name: Test run: make test From 882edf11d77e1ff8b1df37a4db1db0fed2da3a74 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Jul 2025 23:29:15 -0500 Subject: [PATCH 38/44] Bump github.com/vishvananda/netlink from 1.3.0 to 1.3.1 (#1407) Bumps [github.com/vishvananda/netlink](https://github.com/vishvananda/netlink) from 1.3.0 to 1.3.1. - [Release notes](https://github.com/vishvananda/netlink/releases) - [Commits](https://github.com/vishvananda/netlink/compare/v1.3.0...v1.3.1) --- updated-dependencies: - dependency-name: github.com/vishvananda/netlink dependency-version: 1.3.1 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index d90a937..a345ed0 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.10.0 - github.com/vishvananda/netlink v1.3.0 + github.com/vishvananda/netlink v1.3.1 golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/net v0.39.0 @@ -48,7 +48,7 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/vishvananda/netns v0.0.4 // indirect + github.com/vishvananda/netns v0.0.5 // indirect golang.org/x/mod v0.23.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.30.0 // indirect diff --git a/go.sum b/go.sum index 920ee57..6aae53a 100644 --- a/go.sum +++ b/go.sum @@ -145,10 +145,10 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= -github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= -github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= +github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= From e4b7dbcfb03678a880caa3c3fba21b8656a15fb5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Jul 2025 23:30:40 -0500 Subject: [PATCH 39/44] Bump dario.cat/mergo from 1.0.1 to 1.0.2 (#1408) Bumps [dario.cat/mergo](https://github.com/imdario/mergo) from 1.0.1 to 1.0.2. - [Release notes](https://github.com/imdario/mergo/releases) - [Commits](https://github.com/imdario/mergo/compare/v1.0.1...v1.0.2) --- updated-dependencies: - dependency-name: dario.cat/mergo dependency-version: 1.0.2 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index a345ed0..d552a7c 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.0 toolchain go1.24.1 require ( - dario.cat/mergo v1.0.1 + dario.cat/mergo v1.0.2 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 diff --git a/go.sum b/go.sum index 6aae53a..a932e58 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= -dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= From b158eb0c4cfe42e63edc0bd1027e52bff18441de Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Wed, 2 Jul 2025 08:47:05 -0400 Subject: [PATCH 40/44] Use a list for relay IPs instead of a map (#1423) * Use a list for relay IPs instead of a map * linter --- control_test.go | 4 ++-- handshake_ix.go | 2 +- handshake_manager.go | 2 +- hostmap.go | 20 +++++++++++++------- hostmap_test.go | 29 +++++++++++++++++++++++++++++ 5 files changed, 46 insertions(+), 11 deletions(-) diff --git a/control_test.go b/control_test.go index e400992..e8a5d31 100644 --- a/control_test.go +++ b/control_test.go @@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { localIndexId: 201, vpnAddrs: []netip.Addr{vpnIp}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, + relays: nil, relayForByAddr: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, @@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { localIndexId: 201, vpnAddrs: []netip.Addr{vpnIp2}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, + relays: nil, relayForByAddr: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, diff --git a/handshake_ix.go b/handshake_ix.go index cf422b9..0548a23 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -249,7 +249,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, + relays: nil, relayForByAddr: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, diff --git a/handshake_manager.go b/handshake_manager.go index 486541b..f92e72d 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -450,7 +450,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, + relays: nil, relayForByAddr: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, diff --git a/hostmap.go b/hostmap.go index 359749b..7b9b8b9 100644 --- a/hostmap.go +++ b/hostmap.go @@ -4,6 +4,7 @@ import ( "errors" "net" "net/netip" + "slices" "sync" "sync/atomic" "time" @@ -68,7 +69,7 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer + relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer // For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data, // modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with // the RelayState Lock held) @@ -79,7 +80,12 @@ type RelayState struct { func (rs *RelayState) DeleteRelay(ip netip.Addr) { rs.Lock() defer rs.Unlock() - delete(rs.relays, ip) + for idx, val := range rs.relays { + if val == ip { + rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...) + return + } + } } func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) { @@ -124,16 +130,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { func (rs *RelayState) InsertRelayTo(ip netip.Addr) { rs.Lock() defer rs.Unlock() - rs.relays[ip] = struct{}{} + if !slices.Contains(rs.relays, ip) { + rs.relays = append(rs.relays, ip) + } } func (rs *RelayState) CopyRelayIps() []netip.Addr { + ret := make([]netip.Addr, len(rs.relays)) rs.RLock() defer rs.RUnlock() - ret := make([]netip.Addr, 0, len(rs.relays)) - for ip := range rs.relays { - ret = append(ret, ip) - } + copy(ret, rs.relays) return ret } diff --git a/hostmap_test.go b/hostmap_test.go index b3580cf..e34a4ad 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHostMap_MakePrimary(t *testing.T) { @@ -215,3 +216,31 @@ func TestHostMap_reload(t *testing.T) { c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) } + +func TestHostMap_RelayState(t *testing.T) { + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + a1 := netip.MustParseAddr("::1") + a2 := netip.MustParseAddr("2001::1") + + h1.relayState.InsertRelayTo(a1) + assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays) + h1.relayState.InsertRelayTo(a2) + assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays) + // Ensure that the first relay added is the first one returned in the copy + currentRelays := h1.relayState.CopyRelayIps() + require.Len(t, currentRelays, 2) + assert.Equal(t, a1, currentRelays[0]) + + // Deleting the last one in the list works ok + h1.relayState.DeleteRelay(a2) + assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays) + + // Deleting an element not in the list works ok + h1.relayState.DeleteRelay(a2) + assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays) + + // Deleting the only element in the list works ok + h1.relayState.DeleteRelay(a1) + assert.Equal(t, []netip.Addr{}, h1.relayState.relays) + +} From 94142aded5b111ee8e5afeac73ba49f9fda3a0d0 Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Wed, 2 Jul 2025 08:48:02 -0400 Subject: [PATCH 41/44] Fix relay migration panic by covering every possible relay state (#1414) --- connection_manager.go | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index 5c9b3a5..f3acc92 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/binary" + "fmt" "net/netip" "sync" "time" @@ -227,21 +228,25 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) var relayFrom netip.Addr var relayTo netip.Addr switch { - case ok && existing.State == Established: - // This relay already exists in newhostinfo, then do nothing. - continue - case ok && existing.State == Requested: - // The relay exists in a Requested state; re-send the request - index = existing.LocalIndex - switch r.Type { - case TerminalType: - relayFrom = n.intf.myVpnAddrs[0] - relayTo = existing.PeerAddr - case ForwardingType: - relayFrom = existing.PeerAddr - relayTo = newhostinfo.vpnAddrs[0] - default: - // should never happen + case ok: + switch existing.State { + case Established, PeerRequested, Disestablished: + // This relay already exists in newhostinfo, then do nothing. + continue + case Requested: + // The relay exists in a Requested state; re-send the request + index = existing.LocalIndex + switch r.Type { + case TerminalType: + relayFrom = n.intf.myVpnAddrs[0] + relayTo = existing.PeerAddr + case ForwardingType: + relayFrom = existing.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] + default: + // should never happen + panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type)) + } } case !ok: n.relayUsedLock.RLock() @@ -267,6 +272,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) relayTo = newhostinfo.vpnAddrs[0] default: // should never happen + panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type)) } } From b3a1f7b0a3053cf277a8af68a48fe3e2cff1e56e Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Wed, 2 Jul 2025 11:37:41 -0400 Subject: [PATCH 42/44] Disable UDP receive error returns due to ICMP messages on Windows. (#1412) (#1415) --- udp/udp_rio_windows.go | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 585b642..886e024 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -92,6 +92,25 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error { // Enable v4 for this socket syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + // Disable reporting of PORT_UNREACHABLE and NET_UNREACHABLE errors from the UDP socket receive call. + // These errors are returned on Windows during UDP receives based on the receipt of ICMP packets. Disable + // the UDP receive error returns with these ioctl calls. + ret := uint32(0) + flag := uint32(0) + size := uint32(unsafe.Sizeof(flag)) + err = syscall.WSAIoctl(syscall.Handle(u.sock), syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0) + if err != nil { + return err + } + ret = 0 + flag = 0 + size = uint32(unsafe.Sizeof(flag)) + SIO_UDP_NETRESET := uint32(syscall.IOC_IN | syscall.IOC_VENDOR | 15) + err = syscall.WSAIoctl(syscall.Handle(u.sock), SIO_UDP_NETRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0) + if err != nil { + return err + } + err = u.rx.Open() if err != nil { return err @@ -122,8 +141,12 @@ func (u *RIOConn) ListenOut(r EncReader) { // Just read one packet at a time n, rua, err := u.receive(buffer) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + if errors.Is(err, net.ErrClosed) { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + u.l.WithError(err).Error("unexpected udp socket receive error") + continue } r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) From c2420642a098b0a0a16089df2e9f1e982a0a123f Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 2 Jul 2025 15:50:22 -0500 Subject: [PATCH 43/44] Darwin udp fix (#1428) --- udp/errors.go | 5 ++ udp/udp_darwin.go | 164 ++++++++++++++++++++++++++++++++++++++++++--- udp/udp_generic.go | 3 +- udp/udp_linux.go | 2 +- 4 files changed, 162 insertions(+), 12 deletions(-) create mode 100644 udp/errors.go diff --git a/udp/errors.go b/udp/errors.go new file mode 100644 index 0000000..12a8487 --- /dev/null +++ b/udp/errors.go @@ -0,0 +1,5 @@ +package udp + +import "errors" + +var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote") diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 183ac7a..5e50d8b 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -6,17 +6,61 @@ package udp // Darwin support is primarily implemented in udp_generic, besides NewListenConfig import ( + "context" + "encoding/binary" + "errors" "fmt" "net" "net/netip" "syscall" + "unsafe" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) +type StdConn struct { + *net.UDPConn + isV4 bool + sysFd uintptr + l *logrus.Logger +} + +var _ Conn = &StdConn{} + func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - return NewGenericListener(l, ip, port, multi, batch) + lc := NewListenConfig(multi) + pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) + if err != nil { + return nil, err + } + + if uc, ok := pc.(*net.UDPConn); ok { + c := &StdConn{UDPConn: uc, l: l} + + rc, err := uc.SyscallConn() + if err != nil { + return nil, fmt.Errorf("failed to open udp socket: %w", err) + } + + err = rc.Control(func(fd uintptr) { + c.sysFd = fd + }) + if err != nil { + return nil, fmt.Errorf("failed to get udp fd: %w", err) + } + + la, err := c.LocalAddr() + if err != nil { + return nil, err + } + c.isV4 = la.Addr().Is4() + + return c, nil + } + + return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc) } func NewListenConfig(multi bool) net.ListenConfig { @@ -43,16 +87,116 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *GenericConn) Rebind() error { - rc, err := u.UDPConn.SyscallConn() - if err != nil { - return err +//go:linkname sendto golang.org/x/sys/unix.sendto +//go:noescape +func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error) + +func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { + var sa unsafe.Pointer + var addrLen int32 + + if u.isV4 { + if ap.Addr().Is6() { + return ErrInvalidIPv6RemoteForSocket + } + + var rsa unix.RawSockaddrInet6 + rsa.Family = unix.AF_INET6 + rsa.Addr = ap.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) + sa = unsafe.Pointer(&rsa) + addrLen = syscall.SizeofSockaddrInet4 + } else { + var rsa unix.RawSockaddrInet6 + rsa.Family = unix.AF_INET6 + rsa.Addr = ap.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) + sa = unsafe.Pointer(&rsa) + addrLen = syscall.SizeofSockaddrInet6 } - return rc.Control(func(fd uintptr) { - err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0) - if err != nil { - u.l.WithError(err).Error("Failed to rebind udp socket") + // Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves + // See https://github.com/golang/go/issues/73919 + for { + //_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen) + err := sendto(int(u.sysFd), b, 0, sa, addrLen) + if err == nil { + // Written, get out before the error handling + return nil } - }) + + if errors.Is(err, syscall.EINTR) { + // Write was interrupted, retry + continue + } + + if errors.Is(err, syscall.EAGAIN) { + return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK} + } + + if errors.Is(err, syscall.EBADF) { + return net.ErrClosed + } + + return &net.OpError{Op: "sendto", Err: err} + } +} + +func (u *StdConn) LocalAddr() (netip.AddrPort, error) { + a := u.UDPConn.LocalAddr() + + switch v := a.(type) { + case *net.UDPAddr: + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil + + default: + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) + } +} + +func (u *StdConn) ReloadConfig(c *config.C) { + // TODO +} + +func NewUDPStatsEmitter(udpConns []Conn) func() { + // No UDP stats for non-linux + return func() {} +} + +func (u *StdConn) ListenOut(r EncReader) { + buffer := make([]byte, MTU) + + for { + // Just read one packet at a time + n, rua, err := u.ReadFromUDPAddrPort(buffer) + if err != nil { + if errors.Is(err, net.ErrClosed) { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + u.l.WithError(err).Error("unexpected udp socket receive error") + } + + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + } +} + +func (u *StdConn) Rebind() error { + var err error + if u.isV4 { + err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0) + } else { + err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0) + } + + if err != nil { + u.l.WithError(err).Error("Failed to rebind udp socket") + } + + return nil } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 06a4d53..cb21e57 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -1,6 +1,7 @@ -//go:build (!linux || android) && !e2e_testing +//go:build (!linux || android) && !e2e_testing && !darwin // +build !linux android // +build !e2e_testing +// +build !darwin // udp_generic implements the nebula UDP interface in pure Go stdlib. This // means it can be used on platforms like Darwin and Windows. diff --git a/udp/udp_linux.go b/udp/udp_linux.go index f1936b4..ec0bf64 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -221,7 +221,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { if !ip.Addr().Is4() { - return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + return ErrInvalidIPv6RemoteForSocket } var rsa unix.RawSockaddrInet4 From 52623820c2be9571bb46acd16f5afd7811fe6542 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 3 Jul 2025 09:58:37 -0500 Subject: [PATCH 44/44] Drop inactive tunnels (#1427) --- connection_manager.go | 383 ++++++++++++++++++++----------------- connection_manager_test.go | 176 +++++++++++++---- control.go | 20 +- e2e/handshakes_test.go | 5 +- e2e/router/router.go | 1 + e2e/tunnels_test.go | 57 ++++++ examples/config.yml | 12 ++ handshake_ix.go | 4 +- hostmap.go | 8 + inside.go | 4 +- interface.go | 41 ++-- main.go | 45 ++--- outside.go | 6 +- udp/udp_darwin.go | 2 - 14 files changed, 485 insertions(+), 279 deletions(-) create mode 100644 e2e/tunnels_test.go diff --git a/connection_manager.go b/connection_manager.go index f3acc92..1f9b18b 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -7,11 +7,13 @@ import ( "fmt" "net/netip" "sync" + "sync/atomic" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -28,130 +30,124 @@ const ( ) type connectionManager struct { - in map[uint32]struct{} - inLock *sync.RWMutex - - out map[uint32]struct{} - outLock *sync.RWMutex - // relayUsed holds which relay localIndexs are in use relayUsed map[uint32]struct{} relayUsedLock *sync.RWMutex - hostMap *HostMap - trafficTimer *LockingTimerWheel[uint32] - intf *Interface - pendingDeletion map[uint32]struct{} - punchy *Punchy + hostMap *HostMap + trafficTimer *LockingTimerWheel[uint32] + intf *Interface + punchy *Punchy + + // Configuration settings checkInterval time.Duration pendingDeletionInterval time.Duration - metricsTxPunchy metrics.Counter + inactivityTimeout atomic.Int64 + dropInactive atomic.Bool + + metricsTxPunchy metrics.Counter l *logrus.Logger } -func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager { - var max time.Duration - if checkInterval < pendingDeletionInterval { - max = pendingDeletionInterval - } else { - max = checkInterval +func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { + cm := &connectionManager{ + hostMap: hm, + l: l, + punchy: p, + relayUsed: make(map[uint32]struct{}), + relayUsedLock: &sync.RWMutex{}, + metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), } - nc := &connectionManager{ - hostMap: intf.hostMap, - in: make(map[uint32]struct{}), - inLock: &sync.RWMutex{}, - out: make(map[uint32]struct{}), - outLock: &sync.RWMutex{}, - relayUsed: make(map[uint32]struct{}), - relayUsedLock: &sync.RWMutex{}, - trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max), - intf: intf, - pendingDeletion: make(map[uint32]struct{}), - checkInterval: checkInterval, - pendingDeletionInterval: pendingDeletionInterval, - punchy: punchy, - metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), - l: l, - } + cm.reload(c, true) + c.RegisterReloadCallback(func(c *config.C) { + cm.reload(c, false) + }) - nc.Start(ctx) - return nc + return cm } -func (n *connectionManager) In(localIndex uint32) { - n.inLock.RLock() - // If this already exists, return - if _, ok := n.in[localIndex]; ok { - n.inLock.RUnlock() - return +func (cm *connectionManager) reload(c *config.C, initial bool) { + if initial { + cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second + cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second + + // We want at least a minimum resolution of 500ms per tick so that we can hit these intervals + // pretty close to their configured duration. + // The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it. + minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval) + maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval) + cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration) + } + + if initial || c.HasChanged("tunnels.inactivity_timeout") { + old := cm.getInactivityTimeout() + cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))) + if !initial { + cm.l.WithField("oldDuration", old). + WithField("newDuration", cm.getInactivityTimeout()). + Info("Inactivity timeout has changed") + } + } + + if initial || c.HasChanged("tunnels.drop_inactive") { + old := cm.dropInactive.Load() + cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false)) + if !initial { + cm.l.WithField("oldBool", old). + WithField("newBool", cm.dropInactive.Load()). + Info("Drop inactive setting has changed") + } } - n.inLock.RUnlock() - n.inLock.Lock() - n.in[localIndex] = struct{}{} - n.inLock.Unlock() } -func (n *connectionManager) Out(localIndex uint32) { - n.outLock.RLock() - // If this already exists, return - if _, ok := n.out[localIndex]; ok { - n.outLock.RUnlock() - return - } - n.outLock.RUnlock() - n.outLock.Lock() - n.out[localIndex] = struct{}{} - n.outLock.Unlock() +func (cm *connectionManager) getInactivityTimeout() time.Duration { + return (time.Duration)(cm.inactivityTimeout.Load()) } -func (n *connectionManager) RelayUsed(localIndex uint32) { - n.relayUsedLock.RLock() +func (cm *connectionManager) In(h *HostInfo) { + h.in.Store(true) +} + +func (cm *connectionManager) Out(h *HostInfo) { + h.out.Store(true) +} + +func (cm *connectionManager) RelayUsed(localIndex uint32) { + cm.relayUsedLock.RLock() // If this already exists, return - if _, ok := n.relayUsed[localIndex]; ok { - n.relayUsedLock.RUnlock() + if _, ok := cm.relayUsed[localIndex]; ok { + cm.relayUsedLock.RUnlock() return } - n.relayUsedLock.RUnlock() - n.relayUsedLock.Lock() - n.relayUsed[localIndex] = struct{}{} - n.relayUsedLock.Unlock() + cm.relayUsedLock.RUnlock() + cm.relayUsedLock.Lock() + cm.relayUsed[localIndex] = struct{}{} + cm.relayUsedLock.Unlock() } // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and // resets the state for this local index -func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) { - n.inLock.Lock() - n.outLock.Lock() - _, in := n.in[localIndex] - _, out := n.out[localIndex] - delete(n.in, localIndex) - delete(n.out, localIndex) - n.inLock.Unlock() - n.outLock.Unlock() +func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) { + in := h.in.Swap(false) + out := h.out.Swap(false) + if in || out { + h.lastUsed = now + } return in, out } -func (n *connectionManager) AddTrafficWatch(localIndex uint32) { - // Use a write lock directly because it should be incredibly rare that we are ever already tracking this index - n.outLock.Lock() - if _, ok := n.out[localIndex]; ok { - n.outLock.Unlock() - return +// AddTrafficWatch must be called for every new HostInfo. +// We will continue to monitor the HostInfo until the tunnel is dropped. +func (cm *connectionManager) AddTrafficWatch(h *HostInfo) { + if h.out.Swap(true) == false { + cm.trafficTimer.Add(h.localIndexId, cm.checkInterval) } - n.out[localIndex] = struct{}{} - n.trafficTimer.Add(localIndex, n.checkInterval) - n.outLock.Unlock() } -func (n *connectionManager) Start(ctx context.Context) { - go n.Run(ctx) -} - -func (n *connectionManager) Run(ctx context.Context) { - //TODO: this tick should be based on the min wheel tick? Check firewall - clockSource := time.NewTicker(500 * time.Millisecond) +func (cm *connectionManager) Start(ctx context.Context) { + clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration) defer clockSource.Stop() p := []byte("") @@ -164,61 +160,61 @@ func (n *connectionManager) Run(ctx context.Context) { return case now := <-clockSource.C: - n.trafficTimer.Advance(now) + cm.trafficTimer.Advance(now) for { - localIndex, has := n.trafficTimer.Purge() + localIndex, has := cm.trafficTimer.Purge() if !has { break } - n.doTrafficCheck(localIndex, p, nb, out, now) + cm.doTrafficCheck(localIndex, p, nb, out, now) } } } } -func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { - decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now) +func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { + decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now) switch decision { case deleteTunnel: - if n.hostMap.DeleteHostInfo(hostinfo) { + if cm.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) + cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: - n.intf.sendCloseTunnel(hostinfo) - n.intf.closeTunnel(hostinfo) + cm.intf.sendCloseTunnel(hostinfo) + cm.intf.closeTunnel(hostinfo) case swapPrimary: - n.swapPrimary(hostinfo, primary) + cm.swapPrimary(hostinfo, primary) case migrateRelays: - n.migrateRelayUsed(hostinfo, primary) + cm.migrateRelayUsed(hostinfo, primary) case tryRehandshake: - n.tryRehandshake(hostinfo) + cm.tryRehandshake(hostinfo) case sendTestPacket: - n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) + cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) } - n.resetRelayTrafficCheck(hostinfo) + cm.resetRelayTrafficCheck(hostinfo) } -func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { +func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { if hostinfo != nil { - n.relayUsedLock.Lock() - defer n.relayUsedLock.Unlock() + cm.relayUsedLock.Lock() + defer cm.relayUsedLock.Unlock() // No need to migrate any relays, delete usage info now. for _, idx := range hostinfo.relayState.CopyRelayForIdxs() { - delete(n.relayUsed, idx) + delete(cm.relayUsed, idx) } } } -func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { +func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { relayFor := oldhostinfo.relayState.CopyAllRelayFor() for _, r := range relayFor { @@ -238,7 +234,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnAddrs[0] + relayFrom = cm.intf.myVpnAddrs[0] relayTo = existing.PeerAddr case ForwardingType: relayFrom = existing.PeerAddr @@ -249,23 +245,23 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } case !ok: - n.relayUsedLock.RLock() - if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed { + cm.relayUsedLock.RLock() + if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed { // The relay hasn't been used; don't migrate it. - n.relayUsedLock.RUnlock() + cm.relayUsedLock.RUnlock() continue } - n.relayUsedLock.RUnlock() + cm.relayUsedLock.RUnlock() // The relay doesn't exist at all; create some relay state and send the request. var err error - index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested) + index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { - n.l.WithError(err).Error("failed to migrate relay to new hostinfo") + cm.l.WithError(err).Error("failed to migrate relay to new hostinfo") continue } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnAddrs[0] + relayFrom = cm.intf.myVpnAddrs[0] relayTo = r.PeerAddr case ForwardingType: relayFrom = r.PeerAddr @@ -285,12 +281,12 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) switch newhostinfo.GetCert().Certificate.Version() { case cert.Version1: if !relayFrom.Is4() { - n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") + cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !relayTo.Is4() { - n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") + cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } @@ -302,16 +298,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) req.RelayFromAddr = netAddrToProtoAddr(relayFrom) req.RelayToAddr = netAddrToProtoAddr(relayTo) default: - newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay") + newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay") continue } msg, err := req.Marshal() if err != nil { - n.l.WithError(err).Error("failed to marshal Control message to migrate relay") + cm.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { - n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) - n.l.WithFields(logrus.Fields{ + cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) + cm.l.WithFields(logrus.Fields{ "relayFrom": req.RelayFromAddr, "relayTo": req.RelayToAddr, "initiatorRelayIndex": req.InitiatorRelayIndex, @@ -322,46 +318,45 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } -func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { - n.hostMap.RLock() - defer n.hostMap.RUnlock() +func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { + // Read lock the main hostmap to order decisions based on tunnels being the primary tunnel + cm.hostMap.RLock() + defer cm.hostMap.RUnlock() - hostinfo := n.hostMap.Indexes[localIndex] + hostinfo := cm.hostMap.Indexes[localIndex] if hostinfo == nil { - n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") - delete(n.pendingDeletion, localIndex) + cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap") return doNothing, nil, nil } - if n.isInvalidCertificate(now, hostinfo) { - delete(n.pendingDeletion, hostinfo.localIndexId) + if cm.isInvalidCertificate(now, hostinfo) { return closeTunnel, hostinfo, nil } - primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]] + primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false } // Check for traffic on this hostinfo - inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex) + inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now) // A hostinfo is determined alive if there is incoming traffic if inTraffic { decision := doNothing - if n.l.Level >= logrus.DebugLevel { - hostinfo.logger(n.l). + if cm.l.Level >= logrus.DebugLevel { + hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). Debug("Tunnel status") } - delete(n.pendingDeletion, hostinfo.localIndexId) + hostinfo.pendingDeletion.Store(false) if mainHostInfo { decision = tryRehandshake } else { - if n.shouldSwapPrimary(hostinfo, primary) { + if cm.shouldSwapPrimary(hostinfo, primary) { decision = swapPrimary } else { // migrate the relays to the primary, if in use. @@ -369,46 +364,55 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time } } - n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) if !outTraffic { // Send a punch packet to keep the NAT state alive - n.sendPunch(hostinfo) + cm.sendPunch(hostinfo) } return decision, hostinfo, primary } - if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { + if hostinfo.pendingDeletion.Load() { // We have already sent a test packet and nothing was returned, this hostinfo is dead - hostinfo.logger(n.l). + hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "dead", "method": "active"}). Info("Tunnel status") - delete(n.pendingDeletion, hostinfo.localIndexId) return deleteTunnel, hostinfo, nil } decision := doNothing if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if !outTraffic { + inactiveFor, isInactive := cm.isInactive(hostinfo, now) + if isInactive { + // Tunnel is inactive, tear it down + hostinfo.logger(cm.l). + WithField("inactiveDuration", inactiveFor). + WithField("primary", mainHostInfo). + Info("Dropping tunnel due to inactivity") + + return closeTunnel, hostinfo, primary + } + // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // Just maintain NAT state if configured to do so. - n.sendPunch(hostinfo) - n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + cm.sendPunch(hostinfo) + cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) return doNothing, nil, nil - } - if n.punchy.GetTargetEverything() { + if cm.punchy.GetTargetEverything() { // This is similar to the old punchy behavior with a slight optimization. // We aren't receiving traffic but we are sending it, punch on all known // ips in case we need to re-prime NAT state - n.sendPunch(hostinfo) + cm.sendPunch(hostinfo) } - if n.l.Level >= logrus.DebugLevel { - hostinfo.logger(n.l). + if cm.l.Level >= logrus.DebugLevel { + hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "testing", "method": "active"}). Debug("Tunnel status") } @@ -417,17 +421,33 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time decision = sendTestPacket } else { - if n.l.Level >= logrus.DebugLevel { - hostinfo.logger(n.l).Debugf("Hostinfo sadness") + if cm.l.Level >= logrus.DebugLevel { + hostinfo.logger(cm.l).Debugf("Hostinfo sadness") } } - n.pendingDeletion[hostinfo.localIndexId] = struct{}{} - n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) + hostinfo.pendingDeletion.Store(true) + cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval) return decision, hostinfo, nil } -func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { +func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) { + if cm.dropInactive.Load() == false { + // We aren't configured to drop inactive tunnels + return 0, false + } + + inactiveDuration := now.Sub(hostinfo.lastUsed) + if inactiveDuration < cm.getInactivityTimeout() { + // It's not considered inactive + return inactiveDuration, false + } + + // The tunnel is inactive + return inactiveDuration, true +} + +func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. @@ -435,73 +455,80 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // Only one side should swap because if both swap then we may never resolve to a single tunnel. // vpn addr is static across all tunnels for this host pair so lets // use that to determine if we should consider swapping. - if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { + if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 { // Their primary vpn addr is less than mine. Do not swap. return false } - crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things // settle down. return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } -func (n *connectionManager) swapPrimary(current, primary *HostInfo) { - n.hostMap.Lock() +func (cm *connectionManager) swapPrimary(current, primary *HostInfo) { + cm.hostMap.Lock() // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. - if n.hostMap.Hosts[current.vpnAddrs[0]] == primary { - n.hostMap.unlockedMakePrimary(current) + if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary { + cm.hostMap.unlockedMakePrimary(current) } - n.hostMap.Unlock() + cm.hostMap.Unlock() } // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid // check and return true. -func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { +func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { remoteCert := hostinfo.GetCert() if remoteCert == nil { return false } - caPool := n.intf.pki.GetCAPool() + caPool := cm.intf.pki.GetCAPool() err := caPool.VerifyCachedCertificate(now, remoteCert) if err == nil { return false } - if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { + if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { // Block listed certificates should always be disconnected return false } - hostinfo.logger(n.l).WithError(err). + hostinfo.logger(cm.l).WithError(err). WithField("fingerprint", remoteCert.Fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") return true } -func (n *connectionManager) sendPunch(hostinfo *HostInfo) { - if !n.punchy.GetPunch() { +func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { + if !cm.punchy.GetPunch() { // Punching is disabled return } - if n.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { - n.metricsTxPunchy.Inc(1) - n.intf.outside.WriteTo([]byte{1}, addr) + if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { + // Do not punch to lighthouses, we assume our lighthouse update interval is good enough. + // In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse + // would lose the ability to notify us and punchy.respond would become unreliable. + return + } + + if cm.punchy.GetTargetEverything() { + hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { + cm.metricsTxPunchy.Inc(1) + cm.intf.outside.WriteTo([]byte{1}, addr) }) } else if hostinfo.remote.IsValid() { - n.metricsTxPunchy.Inc(1) - n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) + cm.metricsTxPunchy.Inc(1) + cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } } -func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - cs := n.intf.pki.getCertState() +func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { + cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert myCrt := cs.getCertificate(curCrt.Version()) if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { @@ -509,9 +536,9 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { return } - n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") - n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index d1c5ba3..ecd2880 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -1,7 +1,6 @@ package nebula import ( - "context" "crypto/ed25519" "crypto/rand" "net/netip" @@ -64,10 +63,10 @@ func Test_NewConnectionManagerTest(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - punchy := NewPunchyFromConfig(l, config.NewC(l)) - nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) + conf := config.NewC(l) + punchy := NewPunchyFromConfig(l, conf) + nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) @@ -85,32 +84,33 @@ func Test_NewConnectionManagerTest(t *testing.T) { nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp - nc.Out(hostinfo.localIndexId) - nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) + nc.Out(hostinfo) + nc.In(hostinfo) + assert.False(t, hostinfo.pendingDeletion.Load()) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.out, hostinfo.localIndexId) + assert.True(t, hostinfo.out.Load()) + assert.True(t, hostinfo.in.Load()) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) // Do another traffic check tick, this host should be pending deletion now - nc.Out(hostinfo.localIndexId) + nc.Out(hostinfo) + assert.True(t, hostinfo.out.Load()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.True(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) + assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } @@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - punchy := NewPunchyFromConfig(l, config.NewC(l)) - nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) + conf := config.NewC(l) + punchy := NewPunchyFromConfig(l, conf) + nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) @@ -167,33 +167,129 @@ func Test_NewConnectionManagerTest2(t *testing.T) { nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp - nc.Out(hostinfo.localIndexId) - nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0]) + nc.Out(hostinfo) + nc.In(hostinfo) + assert.True(t, hostinfo.in.Load()) + assert.True(t, hostinfo.out.Load()) + assert.False(t, hostinfo.pendingDeletion.Load()) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) // Do another traffic check tick, this host should be pending deletion now - nc.Out(hostinfo.localIndexId) + nc.Out(hostinfo) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.True(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // We saw traffic, should no longer be pending deletion - nc.In(hostinfo.localIndexId) + nc.In(hostinfo) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) + assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) +} + +func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { + l := test.NewLogger() + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")} + preferredRanges := []netip.Prefix{localrange} + + // Very incomplete mock objects + hostMap := newHostMap(l) + hostMap.preferredRanges.Store(&preferredRanges) + + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, + } + + lh := newTestLighthouse() + ifce := &Interface{ + hostMap: hostMap, + inside: &test.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{}, + lightHouse: lh, + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + } + ifce.pki.cs.Store(cs) + + // Create manager + conf := config.NewC(l) + conf.Settings["tunnels"] = map[string]any{ + "drop_inactive": true, + } + punchy := NewPunchyFromConfig(l, conf) + nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + assert.True(t, nc.dropInactive.Load()) + nc.intf = ifce + + // Add an ip we have established a connection w/ to hostmap + hostinfo := &HostInfo{ + vpnAddrs: vpnAddrs, + localIndexId: 1099, + remoteIndexId: 9901, + } + hostinfo.ConnectionState = &ConnectionState{ + myCert: &dummyCert{version: cert.Version1}, + H: &noise.HandshakeState{}, + } + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) + + // Do a traffic check tick, in and out should be cleared but should not be pending deletion + nc.Out(hostinfo) + nc.In(hostinfo) + assert.True(t, hostinfo.out.Load()) + assert.True(t, hostinfo.in.Load()) + + now := time.Now() + decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now) + assert.Equal(t, tryRehandshake, decision) + assert.Equal(t, now, hostinfo.lastUsed) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) + + decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5)) + assert.Equal(t, doNothing, decision) + assert.Equal(t, now, hostinfo.lastUsed) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) + + // Do another traffic check tick, should still not be pending deletion + decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10)) + assert.Equal(t, doNothing, decision) + assert.Equal(t, now, hostinfo.lastUsed) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) + assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) + + // Finally advance beyond the inactivity timeout + decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10)) + assert.Equal(t, closeTunnel, decision) + assert.Equal(t, now, hostinfo.lastUsed) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } @@ -264,10 +360,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.disconnectInvalid.Store(true) // Create manager - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - punchy := NewPunchyFromConfig(l, config.NewC(l)) - nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) + conf := config.NewC(l) + punchy := NewPunchyFromConfig(l, conf) + nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + nc.intf = ifce ifce.connectionManager = nc hostinfo := &HostInfo{ diff --git a/control.go b/control.go index 016a79b..f8567b5 100644 --- a/control.go +++ b/control.go @@ -26,14 +26,15 @@ type controlHostLister interface { } type Control struct { - f *Interface - l *logrus.Logger - ctx context.Context - cancel context.CancelFunc - sshStart func() - statsStart func() - dnsStart func() - lighthouseStart func() + f *Interface + l *logrus.Logger + ctx context.Context + cancel context.CancelFunc + sshStart func() + statsStart func() + dnsStart func() + lighthouseStart func() + connectionManagerStart func(context.Context) } type ControlHostInfo struct { @@ -63,6 +64,9 @@ func (c *Control) Start() { if c.dnsStart != nil { go c.dnsStart() } + if c.connectionManagerStart != nil { + go c.connectionManagerStart(c.ctx) + } if c.lighthouseStart != nil { c.lighthouseStart() } diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index bc080ce..53d3738 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -506,7 +506,7 @@ func TestReestablishRelays(t *testing.T) { curIndexes := len(myControl.GetHostmap().Indexes) for curIndexes >= start { curIndexes = len(myControl.GetHostmap().Indexes) - r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes) + r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -1052,6 +1052,9 @@ func TestRehandshakingLoser(t *testing.T) { t.Log("Stand up a tunnel between me and them") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") diff --git a/e2e/router/router.go b/e2e/router/router.go index 5e52ed7..c8264ab 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -700,6 +700,7 @@ func (r *R) FlushAll() { r.Unlock() panic("Can't FlushAll for host: " + p.To.String()) } + receiver.InjectUDPPacket(p) r.Unlock() } } diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go new file mode 100644 index 0000000..55974f0 --- /dev/null +++ b/e2e/tunnels_test.go @@ -0,0 +1,57 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" +) + +func TestDropInactiveTunnels(t *testing.T) { + // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides + // under ideal conditions + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + + r.Log("Assert the tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + r.Log("Go inactive and wait for the tunnels to get dropped") + waitStart := time.Now() + for { + myIndexes := len(myControl.GetHostmap().Indexes) + theirIndexes := len(theirControl.GetHostmap().Indexes) + if myIndexes == 0 && theirIndexes == 0 { + break + } + + since := time.Since(waitStart) + r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since) + if since > time.Second*30 { + t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds") + } + + time.Sleep(1 * time.Second) + r.FlushAll() + } + + r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart)) + myControl.Stop() + theirControl.Stop() +} diff --git a/examples/config.yml b/examples/config.yml index eec4f1c..42c32c8 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -338,6 +338,18 @@ logging: # after receiving the response for lighthouse queries #trigger_buffer: 64 +# Tunnel manager settings +#tunnels: + # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has + # elapsed. + # In general, it is a good idea to enable this setting. It will be enabled by default in a future release. + # This setting is reloadable + #drop_inactive: false + + # inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered + # inactive and eligible to be dropped. + # This setting is reloadable + #inactivity_timeout: 10m # Nebula security group configuration firewall: diff --git a/handshake_ix.go b/handshake_ix.go index 0548a23..d53e5a7 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -457,7 +457,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake message sent") } - f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) + f.connectionManager.AddTrafficWatch(hostinfo) hostinfo.remotes.ResetBlockedRemotes() @@ -652,7 +652,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) - f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) + f.connectionManager.AddTrafficWatch(hostinfo) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) diff --git a/hostmap.go b/hostmap.go index 7b9b8b9..7e3b1bd 100644 --- a/hostmap.go +++ b/hostmap.go @@ -256,6 +256,14 @@ type HostInfo struct { // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. next, prev *HostInfo + + //TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing + in, out, pendingDeletion atomic.Bool + + // lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use. + // This value will be behind against actual tunnel utilization in the hot path. + // This should only be used by the ConnectionManagers ticker routine. + lastUsed time.Time } type ViaSender struct { diff --git a/inside.go b/inside.go index 239ea6a..d24ed31 100644 --- a/inside.go +++ b/inside.go @@ -288,7 +288,7 @@ func (f *Interface) SendVia(via *HostInfo, c := via.ConnectionState.messageCounter.Add(1) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) - f.connectionManager.Out(via.localIndexId) + f.connectionManager.Out(via) // Authenticate the header and payload, but do not encrypt for this message type. // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. @@ -356,7 +356,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) - f.connectionManager.Out(hostinfo.localIndexId) + f.connectionManager.Out(hostinfo) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // all our addrs and enable a faster roaming. diff --git a/interface.go b/interface.go index ddd0681..082906d 100644 --- a/interface.go +++ b/interface.go @@ -24,23 +24,23 @@ import ( const mtu = 9001 type InterfaceConfig struct { - HostMap *HostMap - Outside udp.Conn - Inside overlay.Device - pki *PKI - Firewall *Firewall - ServeDns bool - HandshakeManager *HandshakeManager - lightHouse *LightHouse - checkInterval time.Duration - pendingDeletionInterval time.Duration - DropLocalBroadcast bool - DropMulticast bool - routines int - MessageMetrics *MessageMetrics - version string - relayManager *relayManager - punchy *Punchy + HostMap *HostMap + Outside udp.Conn + Inside overlay.Device + pki *PKI + Cipher string + Firewall *Firewall + ServeDns bool + HandshakeManager *HandshakeManager + lightHouse *LightHouse + connectionManager *connectionManager + DropLocalBroadcast bool + DropMulticast bool + routines int + MessageMetrics *MessageMetrics + version string + relayManager *relayManager + punchy *Punchy tryPromoteEvery uint32 reQueryEvery uint32 @@ -157,6 +157,9 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { if c.Firewall == nil { return nil, errors.New("no firewall rules") } + if c.connectionManager == nil { + return nil, errors.New("no connection manager") + } cs := c.pki.getCertState() ifce := &Interface{ @@ -181,7 +184,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { myVpnAddrsTable: cs.myVpnAddrsTable, myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, relayManager: c.relayManager, - + connectionManager: c.connectionManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), @@ -198,7 +201,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) - ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) + ifce.connectionManager.intf = ifce return ifce, nil } diff --git a/main.go b/main.go index b278fa6..eb296fb 100644 --- a/main.go +++ b/main.go @@ -185,6 +185,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) + connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) @@ -220,31 +221,26 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - checkInterval := c.GetInt("timers.connection_alive_interval", 5) - pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10) - ifConfig := &InterfaceConfig{ - HostMap: hostMap, - Inside: tun, - Outside: udpConns[0], - pki: pki, - Firewall: fw, - ServeDns: serveDns, - HandshakeManager: handshakeManager, - lightHouse: lightHouse, - checkInterval: time.Second * time.Duration(checkInterval), - pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval), - tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery), - reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery), - reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), - DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), - DropMulticast: c.GetBool("tun.drop_multicast", false), - routines: routines, - MessageMetrics: messageMetrics, - version: buildVersion, - relayManager: NewRelayManager(ctx, l, hostMap, c), - punchy: punchy, - + HostMap: hostMap, + Inside: tun, + Outside: udpConns[0], + pki: pki, + Firewall: fw, + ServeDns: serveDns, + HandshakeManager: handshakeManager, + connectionManager: connManager, + lightHouse: lightHouse, + tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery), + reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery), + reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), + DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), + DropMulticast: c.GetBool("tun.drop_multicast", false), + routines: routines, + MessageMetrics: messageMetrics, + version: buildVersion, + relayManager: NewRelayManager(ctx, l, hostMap, c), + punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, } @@ -296,5 +292,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg statsStart, dnsStart, lightHouse.StartUpdateWorker, + connManager.Start, }, nil } diff --git a/outside.go b/outside.go index 6d4127d..8720eef 100644 --- a/outside.go +++ b/outside.go @@ -81,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] // Pull the Roaming parts up here, and return in all call paths. f.handleHostRoaming(hostinfo, ip) // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo.localIndexId) + f.connectionManager.In(hostinfo) f.connectionManager.RelayUsed(h.RemoteIndex) relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) @@ -213,7 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] f.handleHostRoaming(hostinfo, ip) - f.connectionManager.In(hostinfo.localIndexId) + f.connectionManager.In(hostinfo) } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote @@ -498,7 +498,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - f.connectionManager.In(hostinfo.localIndexId) + f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 5e50d8b..c0c6233 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -3,8 +3,6 @@ package udp -// Darwin support is primarily implemented in udp_generic, besides NewListenConfig - import ( "context" "encoding/binary"