Compare commits

...

28 Commits

Author SHA1 Message Date
JackDoan
f2bb43fb42 remove gso, leave nice pretty GRO 2025-11-07 11:24:14 -06:00
JackDoan
7999b62147 cursed gso 2025-11-06 17:56:46 -06:00
JackDoan
2ab75709ad hmm 2025-11-04 15:40:33 -06:00
Nate Brown
2ea8a72d5c dunno 2025-10-05 23:23:30 -05:00
Nate Brown
663232e1fc Testing the concept 2025-10-05 23:23:10 -05:00
Nate Brown
2f48529e8b Cleanup and note more work 2025-10-05 23:23:08 -05:00
Nate Brown
f3e1ad64cd Try the timeout 2025-10-05 23:22:29 -05:00
Nate Brown
1d8112a329 Revert "More playing" way too much garbage emitted
This reverts commit fa098c551a.
2025-10-05 23:22:29 -05:00
Nate Brown
31eea0cc94 More playing 2025-10-05 23:22:29 -05:00
Nate Brown
dbba4a4c77 Playing 2025-10-05 23:22:29 -05:00
Nate Brown
194fde45da non-blocking io for linux 2025-10-05 23:22:27 -05:00
Nate Brown
f46b83f2c4 Remove more os.Exit calls and give a more reliable wait for stop function 2025-10-05 23:20:43 -05:00
Nate Brown
fb7f0c3657 Use x/net/route to manage routes directly (#1488) 2025-10-03 10:59:53 -05:00
sl274
b1f53d8d25 Support IPv6 tunneling in FreeBSD (#1399)
Recent merge of cert-v2 support introduced the ability to tunnel IPv6. However, FreeBSD's IPv6 tunneling does not work for 2 reasons:
* The ifconfig commands did not work for IPv6 addresses
* The tunnel device was not configured for link-layer mode, so it only supported IPv4

This PR improves FreeBSD tunneling support in 3 ways:
* Use ioctl instead of exec'ing ifconfig to configure the interface, with additional logic to support IPv6
* Configure the tunnel in link-layer mode, allowing IPv6 traffic
* Use readv() and writev() to communicate with the tunnel device, to avoid the need to copy the packet buffer
2025-10-02 21:54:30 -05:00
Jack Doan
8824eeaea2 helper functions to more correctly marshal curve 25519 public keys (#1481) 2025-10-02 13:56:41 -05:00
dependabot[bot]
071589f7c7 Bump actions/setup-go from 5 to 6 (#1469)
* Bump actions/setup-go from 5 to 6

Bumps [actions/setup-go](https://github.com/actions/setup-go) from 5 to 6.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Hardcode the last one to go v1.25

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-10-02 00:05:12 -05:00
Jack Doan
f1e992f6dd don't require a detailsVpnAddr in a HostUpdateNotification (#1472)
* don't require a detailsVpnAddr in a HostUpdateNotification

* don't send our own addr on HostUpdateNotification for v2
2025-09-29 13:43:12 -05:00
Jack Doan
1ea5f776d7 update to go 1.25, use the cool new ECDSA key marshalling functions (#1483)
* update to go 1.25, use the cool new ECDSA key marshalling functions

* bonk the runners

* actually bump go.mod

* bump golangci-lint
2025-09-29 13:02:25 -05:00
Henry Graham
4cdeb284ef Set CKA_VALUE_LEN attribute in DeriveNoise (#1482) 2025-09-25 13:24:52 -05:00
Jack Doan
5cccd39465 update RemoteList.vpnAddrs when we complete a handshake (#1467) 2025-09-10 09:44:25 -05:00
Jack Doan
8196c22b5a store lighthouses as a slice (#1473)
* store lighthouses as a slice. If you have fewer than 16 lighthouses (and fewer than 16 vpnaddrs on a host, I guess), it's faster
2025-09-10 09:43:25 -05:00
Jack Doan
65cc253c19 prevent linux from assigning ipv6 link-local addresses (#1476) 2025-09-09 13:25:23 -05:00
Wade Simmons
73cfa7b5b1 add firewall tests for ipv6 (#1451)
Test things like cidr and local_cidr with ipv6 addresses, to ensure
everything is working correctly.
2025-09-08 13:57:36 -04:00
Jack Doan
768325c9b4 cert-v2 chores (#1466) 2025-09-05 15:08:22 -05:00
Jack Doan
932e329164 Don't delete static host mappings for non-primary IPs (#1464)
* Don't delete a vpnaddr if it's part of a certificate that contains a vpnaddr that's in the static host map

* remove unused arg from ConnectionManager.shouldSwapPrimary()
2025-09-04 14:49:40 -05:00
Jack Doan
4bea299265 don't send recv errors for packets outside the connection window anymore (#1463)
* don't send recv errors for packets outside the connection window anymore

* Pull in fix from #1459, add my opinion on maxRecvError

* remove recv_error counter entirely
2025-09-03 11:52:52 -05:00
Wade Simmons
5cff83b282 netlink: ignore route updates with no destination (#1437)
Currently we assume each route update must have a destination, but we
should check that it is set before we try to use it.

See: #1436
2025-08-25 13:05:35 -05:00
Wade Simmons
7da79685ff fix lighthouse.calculated_remotes parsing (#1438)
Some checks failed
gofmt / Run gofmt (push) Successful in 27s
smoke-extra / Run extra smoke tests (push) Failing after 21s
smoke / Run multi node smoke test (push) Failing after 1m21s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m9s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m16s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m41s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
This was broken with the change to yaml.v3:

- https://github.com/slackhq/nebula/pull/1148

We forgot to update these references to `map[string]any`.

Without this fix, Nebula crashes with an error like this:

    {"error":"config `lighthouse.calculated_remotes` has invalid type: map[string]interface {}","level":"error","msg":"Invalid lighthouse.calculated_remotes","time":"2025-07-29T15:50:06.479499Z"}
2025-07-29 13:12:07 -04:00
47 changed files with 1806 additions and 383 deletions

View File

@@ -16,9 +16,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Install goimports - name: Install goimports

View File

@@ -12,9 +12,9 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -35,9 +35,9 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -68,9 +68,9 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Import certificates - name: Import certificates

View File

@@ -22,9 +22,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version-file: 'go.mod' go-version: '1.25'
check-latest: true check-latest: true
- name: add hashicorp source - name: add hashicorp source

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: build - name: build

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -34,7 +34,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v8 uses: golangci/golangci-lint-action@v8
with: with:
version: v2.1 version: v2.5
- name: Test - name: Test
run: make test run: make test
@@ -58,9 +58,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -79,9 +79,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.22' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -100,9 +100,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build nebula - name: Build nebula
@@ -117,7 +117,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v8 uses: golangci/golangci-lint-action@v8
with: with:
version: v2.1 version: v2.5
- name: Test - name: Test
run: make test run: make test

View File

@@ -5,6 +5,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
type Bits struct { type Bits struct {
length uint64 length uint64
current uint64 current uint64
@@ -43,7 +44,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
} }
// Not within the window // Not within the window
l.Debugf("rejected a packet (top) %d %d\n", b.current, i) l.Error("rejected a packet (top) %d %d\n", b.current, i)
return false return false
} }

View File

@@ -84,16 +84,11 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
calculatedRemotes := new(bart.Table[[]*calculatedRemote]) calculatedRemotes := new(bart.Table[[]*calculatedRemote])
rawMap, ok := value.(map[any]any) rawMap, ok := value.(map[string]any)
if !ok { if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
} }
for rawKey, rawValue := range rawMap { for rawCIDR, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
cidr, err := netip.ParsePrefix(rawCIDR) cidr, err := netip.ParsePrefix(rawCIDR)
if err != nil { if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
@@ -129,7 +124,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
} }
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) { func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[any]any) rawMap, ok := raw.(map[string]any)
if !ok { if !ok {
return nil, fmt.Errorf("invalid type: %T", raw) return nil, fmt.Errorf("invalid type: %T", raw)
} }

View File

@@ -58,6 +58,9 @@ type Certificate interface {
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations. // PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
PublicKey() []byte PublicKey() []byte
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
MarshalPublicKeyPEM() []byte
// Curve identifies which curve was used for the PublicKey and Signature. // Curve identifies which curve was used for the PublicKey and Signature.
Curve() Curve Curve() Curve
@@ -135,8 +138,7 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
case Version2: case Version2:
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve) c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
default: default:
//TODO: CERT-V2 make a static var return nil, ErrUnknownVersion
return nil, fmt.Errorf("unknown certificate version %d", v)
} }
if err != nil { if err != nil {

View File

@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey return c.details.publicKey
} }
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV1) Signature() []byte { func (c *certificateV1) Signature() []byte {
return c.signature return c.signature
} }
@@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
case Curve_CURVE25519: case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature) return ed25519.Verify(key, b, c.signature)
case Curve_P256: case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key) pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} if err != nil {
return false
}
hashed := sha256.Sum256(b) hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default: default:

View File

@@ -1,6 +1,7 @@
package cert package cert
import ( import (
"crypto/ed25519"
"fmt" "fmt"
"net/netip" "net/netip"
"testing" "testing"
@@ -13,6 +14,7 @@ import (
) )
func TestCertificateV1_Marshal(t *testing.T) { func TestCertificateV1_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second) before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab") pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups()) assert.Equal(t, nc.Groups(), nc2.Groups())
} }
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV1_Expired(t *testing.T) { func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{ nc := certificateV1{
details: detailsV1{ details: detailsV1{

View File

@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
return c.publicKey return c.publicKey
} }
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV2) Signature() []byte { func (c *certificateV2) Signature() []byte {
return c.signature return c.signature
} }
@@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
case Curve_CURVE25519: case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature) return ed25519.Verify(key, b, c.signature)
case Curve_P256: case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key) pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} if err != nil {
return false
}
hashed := sha256.Sum256(b) hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default: default:

View File

@@ -15,6 +15,7 @@ import (
) )
func TestCertificateV2_Marshal(t *testing.T) { func TestCertificateV2_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second) before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab") pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups()) assert.Equal(t, nc.Groups(), nc2.Groups())
} }
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
publicKey: pubKey,
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.curve = Curve_P256
nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV2_Expired(t *testing.T) { func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{ nc := certificateV2{
details: detailsV2{ details: detailsV2{

View File

@@ -20,6 +20,7 @@ var (
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate") ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")

View File

@@ -7,19 +7,26 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
const ( const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE" CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2" CertificateV2Banner = "NEBULA CERTIFICATE V2"
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" )
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" const ( //key-agreement-key banners
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
)
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
const ( //signing key banners
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
) )
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed // UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -51,6 +58,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
} }
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
} else {
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
}
}
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve { switch curve {
case Curve_CURVE25519: case Curve_CURVE25519:
@@ -62,6 +79,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
} }
} }
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b) k, r := pem.Decode(b)
if k == nil { if k == nil {
@@ -73,7 +103,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner: case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32 expectedLen = 32
curve = Curve_CURVE25519 curve = Curve_CURVE25519
case P256PublicKeyBanner: case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
// Uncompressed // Uncompressed
expectedLen = 65 expectedLen = 65
curve = Curve_P256 curve = Curve_P256

View File

@@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
} }
func TestUnmarshalPublicKeyFromPEM(t *testing.T) { func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY----- -----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
} }
func TestUnmarshalX25519PublicKey(t *testing.T) { func TestUnmarshalX25519PublicKey(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY----- -----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY----- -----END NEBULA P256 PUBLIC KEY-----
`)
oldPubP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`) `)
shortKey := []byte(`# A short key shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY----- -----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`) -END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case // Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32) assert.Len(t, k, 32)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Success test case // Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65) assert.Len(t, k, 65)

View File

@@ -7,7 +7,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"math/big"
"net/netip" "net/netip"
"time" "time"
) )
@@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
} }
return t.SignWith(signer, curve, sp) return t.SignWith(signer, curve, sp)
case Curve_P256: case Curve_P256:
pk := &ecdsa.PrivateKey{ pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
PublicKey: ecdsa.PublicKey{ if err != nil {
Curve: elliptic.P256(), return nil, err
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
} }
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) { sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA // We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1 // - https://pkg.go.dev/crypto/ecdsa#SignASN1

191
cmd/gso/gso.go Normal file
View File

@@ -0,0 +1,191 @@
package main
import (
"encoding/binary"
"errors"
"flag"
"fmt"
"log"
"net"
"net/netip"
"time"
"unsafe"
"golang.org/x/sys/unix"
)
const (
// UDP_SEGMENT enables GSO segmentation
UDP_SEGMENT = 103
// Maximum GSO segment size (typical MTU - headers)
maxGSOSize = 1400
)
func main() {
destAddr := flag.String("dest", "10.4.0.16:4202", "Destination address")
gsoSize := flag.Int("gso", 1400, "GSO segment size")
totalSize := flag.Int("size", 14000, "Total payload size to send")
count := flag.Int("count", 1, "Number of packets to send")
flag.Parse()
if *gsoSize > maxGSOSize {
log.Fatalf("GSO size %d exceeds maximum %d", *gsoSize, maxGSOSize)
}
// Resolve destination address
_, err := net.ResolveUDPAddr("udp", *destAddr)
if err != nil {
log.Fatalf("Failed to resolve address: %v", err)
}
// Create a raw UDP socket with GSO support
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err != nil {
log.Fatalf("Failed to create socket: %v", err)
}
defer unix.Close(fd)
// Bind to a local address
localAddr := &unix.SockaddrInet4{
Port: 0, // Let the system choose a port
}
if err := unix.Bind(fd, localAddr); err != nil {
log.Fatalf("Failed to bind socket: %v", err)
}
fmt.Printf("Sending UDP packets with GSO enabled\n")
fmt.Printf("Destination: %s\n", *destAddr)
fmt.Printf("GSO segment size: %d bytes\n", *gsoSize)
fmt.Printf("Total payload size: %d bytes\n", *totalSize)
fmt.Printf("Number of packets: %d\n\n", *count)
// Create payload
payload := make([]byte, *totalSize)
for i := range payload {
payload[i] = byte(i % 256)
}
dest := netip.MustParseAddrPort(*destAddr)
//if err := unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_SEGMENT, 1400); err != nil {
// panic(err)
//}
for i := 0; i < *count; i++ {
err := WriteBatch(fd, payload, dest, uint16(*gsoSize), true)
if err != nil {
log.Printf("Send error on packet %d: %v", i, err)
continue
}
if (i+1)%100 == 0 || i == *count-1 {
fmt.Printf("Sent %d packets\n", i+1)
}
}
fmt.Printf("now, let's send without the correct ctrl header\n")
time.Sleep(time.Second)
for i := 0; i < *count; i++ {
err := WriteBatch(fd, payload, dest, uint16(*gsoSize), false)
if err != nil {
log.Printf("Send error on packet %d: %v", i, err)
continue
}
if (i+1)%100 == 0 || i == *count-1 {
fmt.Printf("Sent %d packets\n", i+1)
}
}
}
func WriteBatch(fd int, payload []byte, addr netip.AddrPort, segSize uint16, withHeader bool) error {
msgs := make([]rawMessage, 0, 1)
iovs := make([]iovec, 0, 1)
names := make([][unix.SizeofSockaddrInet6]byte, 0, 1)
sent := 0
pkts := []BatchPacket{
{
Payload: payload,
Addr: addr,
},
}
for _, pkt := range pkts {
if len(pkt.Payload) == 0 {
sent++
continue
}
msgs = append(msgs, rawMessage{})
iovs = append(iovs, iovec{})
names = append(names, [unix.SizeofSockaddrInet6]byte{})
idx := len(msgs) - 1
msg := &msgs[idx]
iov := &iovs[idx]
name := &names[idx]
setIovecSlice(iov, pkt.Payload)
msg.Hdr.Iov = iov
msg.Hdr.Iovlen = 1
if withHeader {
setRawMessageControl(msg, buildGSOControlMessage(segSize)) //
} else {
setRawMessageControl(msg, nil) //
}
msg.Hdr.Flags = 0
nameLen, err := encodeSockaddr(name[:], pkt.Addr)
if err != nil {
return err
}
msg.Hdr.Name = &name[0]
msg.Hdr.Namelen = nameLen
}
if len(msgs) == 0 {
return errors.New("nothing to write")
}
offset := 0
for offset < len(msgs) {
n, _, errno := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(fd),
uintptr(unsafe.Pointer(&msgs[offset])),
uintptr(len(msgs)-offset),
0,
0,
0,
)
if errno != 0 {
if errno == unix.EINTR {
continue
}
return &net.OpError{Op: "sendmmsg", Err: errno}
}
if n == 0 {
break
}
offset += int(n)
}
return nil
}
func buildGSOControlMessage(segSize uint16) []byte {
control := make([]byte, unix.CmsgSpace(2))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
setCmsgLen(hdr, unix.CmsgLen(2))
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
return control
}

85
cmd/gso/helper.go Normal file
View File

@@ -0,0 +1,85 @@
package main
import (
"encoding/binary"
"fmt"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
type iovec struct {
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Pad0 [4]byte
Iov *iovec
Iovlen uint64
Control *byte
Controllen uint64
Flags int32
Pad1 [4]byte
}
type rawMessage struct {
Hdr msghdr
Len uint32
Pad0 [4]byte
}
type BatchPacket struct {
Payload []byte
Addr netip.AddrPort
}
func encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
if addr.Addr().Is4() {
if !addr.Addr().Is4() {
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var sa unix.RawSockaddrInet4
sa.Family = unix.AF_INET
sa.Addr = addr.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet4
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
var sa unix.RawSockaddrInet6
sa.Family = unix.AF_INET6
sa.Addr = addr.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet6
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint64(len(buf))
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint64(len(b))
}

View File

@@ -65,8 +65,16 @@ func main() {
} }
if !*configTest { if !*configTest {
ctrl.Start() wait, err := ctrl.Start()
ctrl.ShutdownBlock() if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
} }
os.Exit(0) os.Exit(0)

View File

@@ -3,6 +3,9 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"log"
"net/http"
_ "net/http/pprof"
"os" "os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -58,10 +61,22 @@ func main() {
os.Exit(1) os.Exit(1)
} }
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest { if !*configTest {
ctrl.Start() wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l) notifyReady(l)
ctrl.ShutdownBlock() wait()
l.Info("Goodbye")
} }
os.Exit(0) os.Exit(0)

View File

@@ -356,7 +356,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
decision = tryRehandshake decision = tryRehandshake
} else { } else {
if cm.shouldSwapPrimary(hostinfo, primary) { if cm.shouldSwapPrimary(hostinfo) {
decision = swapPrimary decision = swapPrimary
} else { } else {
// migrate the relays to the primary, if in use. // migrate the relays to the primary, if in use.
@@ -447,7 +447,7 @@ func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time
return inactiveDuration, true return inactiveDuration, true
} }
func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out. // Let's sort this out.

View File

@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
addrMap: map[netip.Addr]*RemoteList{}, addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan netip.Addr, 10), queryChan: make(chan netip.Addr, 10),
} }
lighthouses := map[netip.Addr]struct{}{} lighthouses := []netip.Addr{}
staticList := map[netip.Addr]struct{}{} staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses) lh.lighthouses.Store(&lighthouses)
@@ -446,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey return d.publicKey
} }
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte { func (d *dummyCert) Signature() []byte {
return d.signature return d.signature
} }

View File

@@ -13,7 +13,9 @@ import (
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
) )
const ReplayWindow = 1024 // TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
// 4092 should be sufficient for 5Gbps
const ReplayWindow = 8192
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState

View File

@@ -2,9 +2,11 @@ package nebula
import ( import (
"context" "context"
"errors"
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -13,6 +15,16 @@ import (
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
) )
type RunState int
const (
Stopped RunState = 0 // The control has yet to be started
Started RunState = 1 // The control has been started
Stopping RunState = 2 // The control is stopping
)
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,6 +38,9 @@ type controlHostLister interface {
} }
type Control struct { type Control struct {
stateLock sync.Mutex
state RunState
f *Interface f *Interface
l *logrus.Logger l *logrus.Logger
ctx context.Context ctx context.Context
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
} }
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() // Start actually runs nebula, this is a nonblocking call.
func (c *Control) Start() { // The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Activate the interface // Activate the interface
c.f.activate() err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created. // Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil { if c.sshStart != nil {
@@ -72,15 +98,33 @@ func (c *Control) Start() {
} }
// Start reading packets. // Start reading packets.
c.f.run() c.state = Started
c.stateLock.Unlock()
return c.f.run(c.ctx)
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
} }
func (c *Control) Context() context.Context { func (c *Control) Context() context.Context {
return c.ctx return c.ctx
} }
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete // Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() { func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from // Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down. // being created while we're shutting them all down.
c.cancel() c.cancel()
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil { if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed") c.l.WithError(err).Error("Close interface failed")
} }
c.l.Info("Goodbye") c.state = Stopped
} }
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

View File

@@ -68,6 +68,9 @@ func TestFirewall_AddRule(t *testing.T) {
ti, err := netip.ParsePrefix("1.2.3.4/32") ti, err := netip.ParsePrefix("1.2.3.4/32")
require.NoError(t, err) require.NoError(t, err)
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -92,12 +95,24 @@ func TestFirewall_AddRule(t *testing.T) {
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -117,6 +132,13 @@ func TestFirewall_AddRule(t *testing.T) {
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -199,6 +221,82 @@ func TestFirewall_Drop(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
c := dummyCert{
name: "host1",
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
}
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) { func BenchmarkFirewallTable_match(b *testing.B) {
f := &Firewall{} f := &Firewall{}
ft := FirewallTable{ ft := FirewallTable{
@@ -208,6 +306,10 @@ func BenchmarkFirewallTable_match(b *testing.B) {
pfix := netip.MustParsePrefix("172.1.1.1/32") pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
@@ -239,6 +341,15 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
} }
}) })
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
ip := netip.MustParsePrefix("fd99::99/128")
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
@@ -252,6 +363,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
} }
}) })
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
@@ -265,6 +388,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
} }
}) })
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) { b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
@@ -289,6 +424,17 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
} }
}) })
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
},
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on name", func(b *testing.B) { b.Run("pass on name", func(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
@@ -447,6 +593,42 @@ func TestFirewall_Drop3(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
network := netip.MustParsePrefix("fd12::34/120")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{network},
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) { func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
@@ -727,6 +909,21 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}

4
go.mod
View File

@@ -1,8 +1,6 @@
module github.com/slackhq/nebula module github.com/slackhq/nebula
go 1.23.0 go 1.25
toolchain go1.24.1
require ( require (
dario.cat/mergo v1.0.2 dario.cat/mergo v1.0.2

View File

@@ -459,7 +459,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.connectionManager.AddTrafficWatch(hostinfo) f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.remotes.ResetBlockedRemotes() hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
return return
} }
@@ -667,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
} }
hostinfo.remotes.ResetBlockedRemotes() hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
f.metricHandshakes.Update(duration) f.metricHandshakes.Update(duration)
return false return false

View File

@@ -17,12 +17,10 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
// const ProbeLen = 100
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
const MaxRemotes = 10 const MaxRemotes = 10
const maxRecvError = 4
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
// 5 allows for an initial handshake and each host pair re-handshaking twice // 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -225,8 +223,7 @@ type HostInfo struct {
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks // vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
// The host may have other vpn addresses that are outside our // The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable // vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr vpnAddrs []netip.Addr
recvError atomic.Uint32
// networks are both all vpn and unsafe networks assigned to this host // networks are both all vpn and unsafe networks assigned to this host
networks *bart.Lite networks *bart.Lite
@@ -733,13 +730,6 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false return false
} }
func (i *HostInfo) RecvErrorExceeded() bool {
if i.recvError.Add(1) >= maxRecvError {
return true
}
return true
}
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 { if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed // Simple case, no CIDRTree needed

View File

@@ -6,8 +6,8 @@ import (
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
"os"
"runtime" "runtime"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -18,6 +18,7 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@@ -87,12 +88,17 @@ type Interface struct {
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []io.ReadWriteCloser
wg sync.WaitGroup
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger l *logrus.Logger
pktPool *packet.Pool
inbound chan *packet.Packet
outbound chan *packet.Packet
} }
type EncWriter interface { type EncWriter interface {
@@ -194,9 +200,15 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
}, },
//TODO: configurable size
inbound: make(chan *packet.Packet, 2048),
outbound: make(chan *packet.Packet, 2048),
l: c.l, l: c.l,
} }
ifce.pktPool = packet.GetPool()
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -209,7 +221,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any // activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However, // other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called. // the interface isn't going to process anything until run() is called.
func (f *Interface) activate() { func (f *Interface) activate() error {
// actually turn on tun dev // actually turn on tun dev
addr, err := f.outside.LocalAddr() addr, err := f.outside.LocalAddr()
@@ -230,33 +242,46 @@ func (f *Interface) activate() {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
if err != nil { if err != nil {
f.l.Fatal(err) return err
} }
} }
f.readers[i] = reader f.readers[i] = reader
} }
if err := f.inside.Activate(); err != nil { if err = f.inside.Activate(); err != nil {
f.inside.Close() f.inside.Close()
f.l.Fatal(err) return err
} }
return nil
} }
func (f *Interface) run() { func (f *Interface) run(c context.Context) (func(), error) {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
// read packets from udp and queue to f.inbound
f.wg.Add(1)
go f.listenOut(i) go f.listenOut(i)
// Launch n queues to read packets from inside tun dev and queue to f.outbound
//todo this never stops f.wg.Add(1)
go f.listenIn(f.readers[i], i)
// Launch n workers to process traffic from f.inbound and smash it onto the inside of the tun
f.wg.Add(1)
go f.workerIn(i, c)
f.wg.Add(1)
go f.workerIn(i, c)
// read from f.outbound and write to UDP (outside the tun)
f.wg.Add(1)
go f.workerOut(i, c)
} }
// Launch n queues to read packets from tun dev return f.wg.Wait, nil
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
}
} }
func (f *Interface) listenOut(i int) { func (f *Interface) listenOut(i int) {
runtime.LockOSThread() runtime.LockOSThread()
var li udp.Conn var li udp.Conn
if i > 0 { if i > 0 {
li = f.writers[i] li = f.writers[i]
@@ -264,41 +289,90 @@ func (f *Interface) listenOut(i int) {
li = f.outside li = f.outside
} }
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) err := li.ListenOut(f.pktPool.Get, f.inbound)
lhh := f.lightHouse.NewRequestHandler() if err != nil && !f.closed.Load() {
plaintext := make([]byte, udp.MTU) f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
h := &header.H{} //TODO: Trigger Control to close
fwPacket := &firewall.Packet{} }
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { f.l.Debugf("underlay reader %v is done", i)
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.wg.Done()
})
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread() runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for { for {
n, err := reader.Read(packet) p := f.pktPool.Get()
n, err := reader.Read(p.Payload)
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() { if !f.closed.Load() {
return f.l.WithError(err).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
} }
break
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) p.Payload = (p.Payload)[:n]
//TODO: nonblocking channel write
f.outbound <- p
//select {
//case f.outbound <- p:
//default:
// f.l.Error("Dropped packet from outbound channel")
//}
}
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) workerIn(i int, ctx context.Context) {
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket2 := &firewall.Packet{}
nb2 := make([]byte, 12, 12)
result2 := make([]byte, mtu)
h := &header.H{}
for {
select {
case p := <-f.inbound:
if p.SegSize > 0 && p.SegSize < len(p.Payload) {
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
end := offset + p.SegSize
if end > len(p.Payload) {
end = len(p.Payload)
}
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload[offset:end], h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
}
} else {
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
}
f.pktPool.Put(p)
case <-ctx.Done():
f.wg.Done()
return
}
}
}
func (f *Interface) workerOut(i int, ctx context.Context) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket1 := &firewall.Packet{}
nb1 := make([]byte, 12, 12)
result1 := make([]byte, mtu)
for {
select {
case data := <-f.outbound:
f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
f.pktPool.Put(data)
case <-ctx.Done():
f.wg.Done()
return
}
} }
} }
@@ -451,6 +525,7 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error { func (f *Interface) Close() error {
f.closed.Store(true) f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers { for _, u := range f.writers {
err := u.Close() err := u.Close()
if err != nil { if err != nil {
@@ -458,6 +533,13 @@ func (f *Interface) Close() error {
} }
} }
// Release the tun device // Release the tun readers
return f.inside.Close() for _, u := range f.readers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing tun device")
}
}
return nil
} }

View File

@@ -24,6 +24,7 @@ import (
) )
var ErrHostNotKnown = errors.New("host not known") var ErrHostNotKnown = errors.New("host not known")
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
type LightHouse struct { type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time //TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
@@ -56,7 +57,7 @@ type LightHouse struct {
// staticList exists to avoid having a bool in each addrMap entry // staticList exists to avoid having a bool in each addrMap entry
// since static should be rare // since static should be rare
staticList atomic.Pointer[map[netip.Addr]struct{}] staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[map[netip.Addr]struct{}] lighthouses atomic.Pointer[[]netip.Addr]
interval atomic.Int64 interval atomic.Int64
updateCancel context.CancelFunc updateCancel context.CancelFunc
@@ -107,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l, l: l,
} }
lighthouses := make(map[netip.Addr]struct{}) lighthouses := make([]netip.Addr, 0)
h.lighthouses.Store(&lighthouses) h.lighthouses.Store(&lighthouses)
staticList := make(map[netip.Addr]struct{}) staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList) h.staticList.Store(&staticList)
@@ -143,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load() return *lh.staticList.Load()
} }
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { func (lh *LightHouse) GetLighthouses() []netip.Addr {
return *lh.lighthouses.Load() return *lh.lighthouses.Load()
} }
@@ -306,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
} }
if initial || c.HasChanged("lighthouse.hosts") { if initial || c.HasChanged("lighthouse.hosts") {
lhMap := make(map[netip.Addr]struct{}) lhList, err := lh.parseLighthouses(c)
err := lh.parseLighthouses(c, lhMap)
if err != nil { if err != nil {
return err return err
} }
lh.lighthouses.Store(&lhMap) lh.lighthouses.Store(&lhList)
if !initial { if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed") lh.l.Info("lighthouse.hosts has changed")
@@ -346,36 +346,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
lhs := c.GetStringSlice("lighthouse.hosts", []string{}) lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 { if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
} }
out := make([]netip.Addr, len(lhs))
for i, host := range lhs { for i, host := range lhs {
addr, err := netip.ParseAddr(host) addr, err := netip.ParseAddr(host)
if err != nil { if err != nil {
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
} }
if !lh.myVpnNetworksTable.Contains(addr) { if !lh.myVpnNetworksTable.Contains(addr) {
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
} }
lhMap[addr] = struct{}{} out[i] = addr
} }
if !lh.amLighthouse && len(lhMap) == 0 { if !lh.amLighthouse && len(out) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries") lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
} }
staticList := lh.GetStaticHostList() staticList := lh.GetStaticHostList()
for lhAddr, _ := range lhMap { for i := range out {
if _, ok := staticList[lhAddr]; !ok { if _, ok := staticList[out[i]]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
} }
} }
return nil return out, nil
} }
func getStaticMapCadence(c *config.C) (time.Duration, error) { func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -486,7 +487,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock() lh.Lock()
defer lh.Unlock() defer lh.Unlock()
// Add an entry if we don't already have one // Add an entry if we don't already have one
return lh.unlockedGetRemoteList(vpnAddrs) return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
} }
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -519,11 +520,15 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
} }
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
// First we check the static mapping // First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
// and do nothing if it is there staticList := lh.GetStaticHostList()
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { for _, addr := range allVpnAddrs {
return if _, ok := staticList[addr]; ok {
return
}
} }
// None of the VpnAddrs were present. Now we can do the deletes.
lh.Lock() lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]] rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok { if ok {
@@ -565,7 +570,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr) am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetAddrs() { for _, addrPort := range hr.GetAddrs() {
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
continue continue
} }
switch { switch {
@@ -627,23 +632,30 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
return len(calculatedV4) > 0 || len(calculatedV6) > 0 return len(calculatedV4) > 0 || len(calculatedV6) > 0
} }
// unlockedGetRemoteList // unlockedGetRemoteList assumes you have the lh lock
// assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
am, ok := lh.addrMap[allAddrs[0]] // before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
if !ok { for i, addr := range allAddrs {
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) am, ok := lh.addrMap[addr]
for _, addr := range allAddrs { if ok {
lh.addrMap[addr] = am if i != 0 {
lh.addrMap[allAddrs[0]] = am
}
return am
} }
} }
am := NewRemoteList(allAddrs, lh.shouldAdd)
for _, addr := range allAddrs {
lh.addrMap[addr] = am
}
return am return am
} }
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow") Trace("remoteAllowList.Allow")
} }
if !allow { if !allow {
@@ -698,19 +710,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
} }
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
if _, ok := lh.GetLighthouses()[vpnAddr]; ok { l := lh.GetLighthouses()
return true for i := range l {
if l[i] == vpnAddr {
return true
}
} }
return false return false
} }
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
l := lh.GetLighthouses() l := lh.GetLighthouses()
for _, a := range vpnAddr { for i := range vpnAddrs {
if _, ok := l[a]; ok { for j := range l {
return true if l[j] == vpnAddrs[i] {
return true
}
} }
} }
return false return false
@@ -752,7 +767,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0 queried := 0
lighthouses := lh.GetLighthouses() lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses { for _, lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr) hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil { if hi != nil {
v = hi.ConnectionState.myCert.Version() v = hi.ConnectionState.myCert.Version()
@@ -870,7 +885,7 @@ func (lh *LightHouse) SendUpdate() {
updated := 0 updated := 0
lighthouses := lh.GetLighthouses() lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses { for _, lhVpnAddr := range lighthouses {
var v cert.Version var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr) hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil { if hi != nil {
@@ -928,7 +943,6 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4, V4AddrPorts: v4,
V6AddrPorts: v6, V6AddrPorts: v6,
RelayVpnAddrs: relays, RelayVpnAddrs: relays,
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
}, },
} }
@@ -1048,19 +1062,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return return
} }
useVersion := cert.Version1 queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
var queryVpnAddr netip.Addr if err != nil {
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
useVersion = 1
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = 2
} else {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery") lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
} }
return return
} }
@@ -1069,9 +1083,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 { if useVersion == cert.Version1 {
if !queryVpnAddr.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := queryVpnAddr.As4() b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else { } else {
@@ -1116,8 +1127,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok { if ok {
whereToPunch = newDest whereToPunch = newDest
} else { } else {
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee if lhh.l.Level >= logrus.DebugLevel {
//choosing to do nothing for now, but maybe we return an error? lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
}
} }
} }
@@ -1176,19 +1188,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if !r.Is4() { if !r.Is4() {
continue continue
} }
b = r.As4() b = r.As4()
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
} }
} else if v == cert.Version2 { } else if v == cert.Version2 {
for _, r := range c.relay.relay { for _, r := range c.relay.relay {
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
} }
} else { } else {
//TODO: CERT-V2 don't panic if lhh.l.Level >= logrus.DebugLevel {
panic("unsupported version") lhh.l.WithField("version", v).Debug("unsupported protocol version")
}
} }
} }
} }
@@ -1198,18 +1208,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
return return
} }
lhh.lh.Lock() certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
var certVpnAddr netip.Addr if lhh.l.Level >= logrus.DebugLevel {
if n.Details.OldVpnAddr != 0 { lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
b := [4]byte{} }
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) return
certVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
} }
relays := n.Details.GetRelays() relays := n.Details.GetRelays()
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr}) am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
@@ -1234,27 +1242,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return return
} }
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
var detailsVpnAddr netip.Addr var detailsVpnAddr netip.Addr
useVersion := cert.Version1 var useVersion cert.Version
if n.Details.OldVpnAddr != 0 { if n.Details.OldVpnAddr != 0 { //v1 always sets this field
b := [4]byte{} b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b) detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1 useVersion = cert.Version1
} else if n.Details.VpnAddr != nil { } else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2 useVersion = cert.Version2
} else { } else {
if lhh.l.Level >= logrus.DebugLevel { detailsVpnAddr = netip.Addr{}
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification") useVersion = cert.Version2
}
return
} }
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? //Simple check that the host sent this not someone else, if detailsVpnAddr is filled
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right? if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
//Simple check that the host sent this not someone else
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
} }
@@ -1268,24 +1273,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays) am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock() am.Unlock()
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck n.Type = NebulaMeta_HostUpdateNotificationAck
switch useVersion {
if useVersion == cert.Version1 { case cert.Version1:
if !fromVpnAddrs[0].Is4() { if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return return
} }
vpnAddrB := fromVpnAddrs[0].As4() vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:]) n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
} else if useVersion == cert.Version2 { case cert.Version2:
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) // do nothing, we want to send a blank message
} else { default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return return
} }
@@ -1303,13 +1308,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr, //It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return return
} }
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
}
return
}
empty := []byte{0} empty := []byte{0}
punch := func(vpnPeer netip.AddrPort) { punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
if !vpnPeer.IsValid() { if !vpnPeer.IsValid() {
return return
} }
@@ -1321,48 +1333,31 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}() }()
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
var logVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
logVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
} }
} }
for _, a := range n.Details.V4AddrPorts { for _, a := range n.Details.V4AddrPorts {
punch(protoV4AddrPortToNetAddrPort(a)) punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
} }
for _, a := range n.Details.V6AddrPorts { for _, a := range n.Details.V6AddrPorts {
punch(protoV6AddrPortToNetAddrPort(a)) punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
} }
// This sends a nebula test packet to the host trying to contact us. In the case // This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish // of a double nat or other difficult scenario, this may help establish
// a tunnel. // a tunnel.
if lhh.lh.punchy.GetRespond() { if lhh.lh.punchy.GetRespond() {
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
go func() { go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay()) time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr) lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
} }
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel. // managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}() }()
} }
} }
@@ -1441,3 +1436,17 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
} }
return netip.Addr{}, false return netip.Addr{}, false
} }
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
if d.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
detailsVpnAddr := netip.AddrFrom4(b)
return detailsVpnAddr, cert.Version1, nil
} else if d.VpnAddr != nil {
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
return detailsVpnAddr, cert.Version2, nil
} else {
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
}
}

View File

@@ -493,3 +493,123 @@ func Test_findNetworkUnion(t *testing.T) {
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok) assert.False(t, ok)
} }
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
testStaticHost := netip.MustParseAddr("10.128.0.42")
//myVpnIp := netip.MustParseAddr("10.128.0.2")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
"10.128.0.42": []any{"1.2.3.4:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//test that we actually have the static entry:
out := lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//bolt on a lower numbered primary IP
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
lh.addrMap[testSameHostNotStatic] = am
out.Rebuild([]netip.Prefix{}) //???
//test that we actually have the static entry:
out = lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
//test that we actually have the static entry for BOTH:
out2 := lh.Query(testSameHostNotStatic)
assert.Same(t, out2, out)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
//verify
out = lh.Query(testSameHostNotStatic)
assert.NotNil(t, out)
if out == nil {
t.Fatal("expected non-nil query for the static host")
}
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
}
func TestLighthouse_DeletesWork(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testHost := netip.MustParseAddr("10.128.0.42")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//insert the host
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
am.vpnAddrs = []netip.Addr{testHost}
am.addrs = []netip.AddrPort{myUdpAddr2}
lh.addrMap[testHost] = am
am.Rebuild([]netip.Prefix{}) //???
//test that we actually have the entry:
out := lh.Query(testHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testHost})
//verify
out = lh.Query(testHost)
assert.Nil(t, out)
}

18
main.go
View File

@@ -284,14 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
return &Control{ return &Control{
ifce, f: ifce,
l, l: l,
ctx, ctx: ctx,
cancel, cancel: cancel,
sshStart, sshStart: sshStart,
statsStart, statsStart: statsStart,
dnsStart, dnsStart: dnsStart,
lightHouse.StartUpdateWorker, lighthouseStart: lightHouse.StartUpdateWorker,
connManager.Start, connectionManagerStart: connManager.Start,
}, nil }, nil
} }

View File

@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return return
} }
//l.Error("in packet ", header, packet[HeaderLen:]) //f.l.Error("in packet ", h)
if ip.IsValid() { if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) { if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
@@ -245,6 +245,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
return return
} }
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
@@ -254,16 +255,18 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
} }
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate exists and the replay protector allows, process packet // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection. if ci == nil {
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
if addr.IsValid() { if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex) f.maybeSendRecvError(addr, h.RemoteIndex)
return false
} else {
return false
} }
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
} }
return true return true
@@ -468,7 +471,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
return false return false
} }
@@ -537,10 +540,6 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
return return
} }
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote.IsValid() && hostinfo.remote != addr { if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return return

View File

@@ -1,6 +1,7 @@
package overlay package overlay
import ( import (
"net"
"net/netip" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -70,3 +71,13 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed return removed
} }
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
@@ -554,13 +553,3 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -10,11 +10,9 @@ import (
"io" "io"
"io/fs" "io/fs"
"net/netip" "net/netip"
"os"
"os/exec"
"strconv"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
@@ -22,12 +20,18 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
) )
const ( const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD // FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678) // For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678 FIODGNAME = 0x80106678
TUNSIFMODE = 0x8004745e
TUNSIFHEAD = 0x80047460
OSIOCAIFADDR_IN6 = 0x8088691b
IN6_IFF_NODAD = 0x0020
) )
type fiodgnameArg struct { type fiodgnameArg struct {
@@ -37,43 +41,159 @@ type fiodgnameArg struct {
} }
type ifreqRename struct { type ifreqRename struct {
Name [16]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
} }
type ifreqDestroy struct { type ifreqDestroy struct {
Name [16]byte Name [unix.IFNAMSIZ]byte
pad [16]byte pad [16]byte
} }
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
}
type ifreqMTU struct {
Name [unix.IFNAMSIZ]byte
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
type tun struct { type tun struct {
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger l *logrus.Logger
devFd int
}
io.ReadWriteCloser func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
} }
func (t *tun) Close() error { func (t *tun) Close() error {
if t.ReadWriteCloser != nil { if t.devFd >= 0 {
if err := t.ReadWriteCloser.Close(); err != nil { err := syscall.Close(t.devFd)
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil { if err != nil {
return err t.l.WithError(err).Error("Error closing device")
} }
defer syscall.Close(s) t.devFd = -1
ifreq := ifreqDestroy{Name: t.deviceBytes()} c := make(chan struct{})
go func() {
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
}
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
}
}()
// Destroy the interface // wait up to 1 second so we start blocking at the ioctl
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) select {
return err case <-c:
case <-time.After(1 * time.Second):
}
} }
return nil return nil
@@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device // Try to open existing tun device
var file *os.File var fd int
var err error var err error
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
if deviceName != "" { if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
} }
if errors.Is(err, fs.ErrNotExist) || deviceName == "" { if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it // If the device doesn't already exist, request a new one and rename it
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0) fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
rawConn, err := file.SyscallConn() // Read the name of the interface
if err != nil { var name [16]byte
return nil, fmt.Errorf("SyscallConn: %v", err) arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
} }
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil { if ctrlErr != nil {
return nil, err return nil, err
} }
@@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
// If the name doesn't match the desired interface name, rename it now // If the name doesn't match the desired interface name, rename it now
if ifName != deviceName { if ifName != deviceName {
s, err := syscall.Socket( s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
t := &tun{ t := &tun{
ReadWriteCloser: file, Device: deviceName,
Device: deviceName, vpnNetworks: vpnNetworks,
vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU),
MTU: c.GetInt("tun.mtu", DefaultMTU), l: l,
l: l, devFd: fd,
} }
err = t.reload(c, true) err = t.reload(c, true)
@@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
func (t *tun) addIp(cidr netip.Prefix) error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error if cidr.Addr().Is4() {
// TODO use syscalls instead of exec.Command ifr := ifreqAlias4{
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) Name: t.deviceBytes(),
t.l.Debug("command: ", cmd.String()) Addr: unix.RawSockaddrInet4{
if err = cmd.Run(); err != nil { Len: unix.SizeofSockaddrInet4,
return fmt.Errorf("failed to run 'ifconfig': %s", err) Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) if cidr.Addr().Is6() {
t.l.Debug("command: ", cmd.String()) ifr := ifreqAlias6{
if err = cmd.Run(); err != nil { Name: t.deviceBytes(),
return fmt.Errorf("failed to run 'route add': %s", err) Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
} }
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) return fmt.Errorf("unknown address type %v", cidr)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks { for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i]) err := t.addIp(t.vpnNetworks[i])
if err != nil { if err != nil {
return err return err
} }
} }
return nil
return t.addRoutes(false)
}
func (t *tun) setMTU() error {
// Set the MTU on the device
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
@@ -268,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) err := addRoute(r.Cidr, t.linkAddr)
t.l.Debug("command: ", cmd.String()) if err != nil {
if err := cmd.Run(); err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
} else { } else {
return retErr return retErr
} }
} else {
t.l.WithField("route", r).Info("Added route")
} }
} }
@@ -289,9 +484,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device) err := delRoute(r.Cidr, t.linkAddr)
t.l.Debug("command: ", cmd.String()) if err != nil {
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.WithField("route", r).Info("Removed route")
@@ -306,3 +500,144 @@ func (t *tun) deviceBytes() (o [16]byte) {
} }
return return
} }
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}

View File

@@ -293,7 +293,6 @@ func (t *tun) addIPs(link netlink.Link) error {
//add all new addresses //add all new addresses
for i := range newAddrs { for i := range newAddrs {
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well //AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil { if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err return err
@@ -361,6 +360,11 @@ func (t *tun) Activate() error {
t.l.WithError(err).Error("Failed to set tun tx queue length") t.l.WithError(err).Error("Failed to set tun tx queue length")
} }
const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation")
}
if err = t.addIPs(link); err != nil { if err = t.addIPs(link); err != nil {
return err return err
} }
@@ -638,6 +642,11 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return return
} }
if r.Dst == nil {
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
return
}
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
if !ok { if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")

45
packet/packet.go Normal file
View File

@@ -0,0 +1,45 @@
package packet
import (
"net/netip"
"sync"
"golang.org/x/sys/unix"
)
const Size = 0xffff
type Packet struct {
Payload []byte
Control []byte
SegSize int
Addr netip.AddrPort
}
func New() *Packet {
return &Packet{
Payload: make([]byte, Size),
Control: make([]byte, unix.CmsgSpace(2)),
}
}
type Pool struct {
pool sync.Pool
}
var bigPool = &Pool{
pool: sync.Pool{New: func() any { return New() }},
}
func GetPool() *Pool {
return bigPool
}
func (p *Pool) Get() *Packet {
return p.pool.Get().(*Packet)
}
func (p *Pool) Put(x *Packet) {
x.Payload = x.Payload[:Size]
p.pool.Put(x)
}

View File

@@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true), pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true), pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
} }
// Set up the parameters which include the peer's public key // Set up the parameters which include the peer's public key

5
pki.go
View File

@@ -173,7 +173,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState) p.cs.Store(newState)
//TODO: CERT-V2 newState needs a stringer that does json
if initial { if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else { } else {
@@ -359,7 +358,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
} }
//TODO: CERT-V2 make sure v2 has v1s address if v1.Networks()[0] != v2.Networks()[0] {
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
}
cs.initiatingVersion = dv cs.initiatingVersion = dv
} }

View File

@@ -190,7 +190,7 @@ type RemoteList struct {
// The full list of vpn addresses assigned to this host // The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr vpnAddrs []netip.Addr
// A deduplicated set of addresses. Any accessor should lock beforehand. // A deduplicated set of underlay addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays. // A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -201,8 +201,10 @@ type RemoteList struct {
// For learned addresses, this is the vpnIp that sent the packet // For learned addresses, this is the vpnIp that sent the packet
cache map[netip.Addr]*cache cache map[netip.Addr]*cache
hr *hostnamesResults hr *hostnamesResults
shouldAdd func(netip.Addr) bool
// shouldAdd is a nillable function that decides if x should be added to addrs.
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake // They should not be tried again during a handshake
@@ -213,7 +215,7 @@ type RemoteList struct {
} }
// NewRemoteList creates a new empty RemoteList // NewRemoteList creates a new empty RemoteList
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
r := &RemoteList{ r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)), vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0), addrs: make([]netip.AddrPort, 0),
@@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
return c return c
} }
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
r.Lock()
r.badRemotes = nil
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
copy(r.vpnAddrs, vpnAddrs)
r.Unlock()
}
// ResetBlockedRemotes locks and clears the blocked remotes list // ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() { func (r *RemoteList) ResetBlockedRemotes() {
r.Lock() r.Lock()
@@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetAddrs() dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs { for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
if !r.unlockedIsBad(addr) { if !r.unlockedIsBad(addr) {
addrs = append(addrs, addr) addrs = append(addrs, addr)
} }

View File

@@ -9,10 +9,13 @@ import (
"math" "math"
"net" "net"
"net/netip" "net/netip"
"os"
"strings" "strings"
"sync" "sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
@@ -43,8 +46,19 @@ type Service struct {
} }
} }
func New(control *nebula.Control) (*Service, error) { func New(config *config.C) (*Service, error) {
control.Start() logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return nil, err
}
wait, err := control.Start()
if err != nil {
return nil, err
}
ctx := control.Context() ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
@@ -141,6 +155,12 @@ func New(control *nebula.Control) (*Service, error) {
} }
}) })
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil return &s, nil
} }

View File

@@ -4,19 +4,19 @@ import (
"net/netip" "net/netip"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
) )
const MTU = 9001 const MTU = 9001
type EncReader func( type EncReader func(*packet.Packet)
addr netip.AddrPort,
payload []byte, type PacketBufferGetter func() *packet.Packet
)
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
Close() error Close() error

View File

@@ -71,15 +71,14 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *GenericConn) ListenOut(r EncReader) { func (u *GenericConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {
// Just read one packet at a time // Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer) n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])

View File

@@ -9,27 +9,26 @@ import (
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type StdConn struct { var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
sysFd int
isV4 bool
l *logrus.Logger
batch int
}
func maybeIPV4(ip net.IP) (net.IP, bool) { type StdConn struct {
ip4 := ip.To4() sysFd int
if ip4 != nil { isV4 bool
return ip4, true l *logrus.Logger
} batch int
return ip, false enableGRO bool
enableGSO bool
//gso gsoState
} }
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -55,6 +54,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
} }
} }
// Set a read timeout
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
}
var sa unix.Sockaddr var sa unix.Sockaddr
if ip.Is4() { if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port} sa4 := &unix.SockaddrInet4{Port: port}
@@ -118,10 +122,10 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
} }
} }
func (u *StdConn) ListenOut(r EncReader) { func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error {
var ip netip.Addr var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch) msgs, packets, names := u.PrepareRawMessages(u.batch, pg)
read := u.ReadMulti read := u.ReadMulti
if u.batch == 1 { if u.batch == 1 {
read = u.ReadSingle read = u.ReadSingle
@@ -130,22 +134,61 @@ func (u *StdConn) ListenOut(r EncReader) {
for { for {
n, err := read(msgs) n, err := read(msgs)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
out := packets[i]
out.Payload = out.Payload[:msgs[i].Len]
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 { if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8]) ip, _ = netip.AddrFromSlice(names[i][4:8])
} else { } else {
ip, _ = netip.AddrFromSlice(names[i][8:24]) ip, _ = netip.AddrFromSlice(names[i][8:24])
} }
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) out.Addr = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
ctrlLen := getRawMessageControlLen(&msgs[i])
if ctrlLen > 0 {
packets[i].SegSize = parseGROControl(packets[i].Control[:ctrlLen])
} else {
packets[i].SegSize = 0
}
pc <- out
//rotate this packet out so we don't overwrite it
packets[i] = pg()
msgs[i].Hdr.Iov.Base = &packets[i].Payload[0]
if u.enableGRO {
msgs[i].Hdr.Control = &packets[i].Control[0]
msgs[i].Hdr.Controllen = uint64(cap(packets[i].Control))
}
} }
} }
} }
func parseGROControl(control []byte) int {
if len(control) == 0 {
return 0
}
cmsgs, err := unix.ParseSocketControlMessage(control)
if err != nil {
return 0
}
for _, c := range cmsgs {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
segSize := int(binary.LittleEndian.Uint16(c.Data[:2]))
return segSize
}
}
return 0
}
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for { for {
n, _, err := unix.Syscall6( n, _, err := unix.Syscall6(
@@ -159,6 +202,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
) )
if err != 0 { if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmsg", Err: err} return 0, &net.OpError{Op: "recvmsg", Err: err}
} }
@@ -180,6 +226,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
) )
if err != 0 { if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmmsg", Err: err} return 0, &net.OpError{Op: "recvmmsg", Err: err}
} }
@@ -221,7 +270,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
if !ip.Addr().Is4() { if !ip.Addr().Is4() {
return ErrInvalidIPv6RemoteForSocket return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
} }
var rsa unix.RawSockaddrInet4 var rsa unix.RawSockaddrInet4
@@ -294,6 +343,28 @@ func (u *StdConn) ReloadConfig(c *config.C) {
u.l.WithError(err).Error("Failed to set listen.so_mark") u.l.WithError(err).Error("Failed to set listen.so_mark")
} }
} }
u.configureGRO(true)
}
func (u *StdConn) configureGRO(enable bool) {
if enable == u.enableGRO {
return
}
if enable {
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
u.l.WithError(err).Warn("Failed to enable UDP GRO")
return
}
u.enableGRO = true
u.l.Info("UDP GRO enabled")
return
}
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
u.l.WithError(err).Warn("Failed to disable UDP GRO")
}
u.enableGRO = false
} }
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {

View File

@@ -7,6 +7,7 @@
package udp package udp
import ( import (
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -33,17 +34,39 @@ type rawMessage struct {
Pad0 [4]byte Pad0 [4]byte
} }
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint64(len(buf))
}
func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}
func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage, []*packet.Packet, [][]byte) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
packets := make([]*packet.Packet, n)
for i := range packets {
packets[i] = pg()
}
//todo?
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, MTU)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, {Base: &packets[i].Payload[0], Len: uint64(packet.Size)},
} }
msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iov = &vs[0]
@@ -51,7 +74,14 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i])) msgs[i].Hdr.Namelen = uint32(len(names[i]))
if u.enableGRO {
msgs[i].Hdr.Control = &packets[i].Control[0]
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = 0
}
} }
return msgs, buffers, names return msgs, packets, names
} }

View File

@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil return nil
} }
func (u *RIOConn) ListenOut(r EncReader) { func (u *RIOConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {