Compare commits

...

43 Commits

Author SHA1 Message Date
Jay Wren
be90e4aa05 handle virtio header in ctrl messages 2025-11-19 17:09:39 -05:00
Jay Wren
bc9711df68 batch more writes 2025-11-19 16:58:59 -05:00
Jay Wren
4e333c76ba write batching 2025-11-19 14:03:36 -05:00
Jay Wren
f29e21b411 don't register metrics in loops 2025-11-19 13:25:25 -05:00
Jay Wren
8b32382cd9 zero copy even with virtioheder 2025-11-19 12:03:38 -05:00
Jay Wren
518a78c9d2 preallocate nonce buffer 2025-11-18 14:19:05 -05:00
Jay Wren
7c3708561d instruments 2025-11-14 14:43:51 -05:00
Jay Wren
a62ffca975 fix 32bit 2025-11-13 15:10:51 -05:00
Jay Wren
226787ea1f prealloc them buffers 2025-11-11 15:20:50 -05:00
Jay Wren
b2bc6a09ca write in batches 2025-11-11 15:06:45 -05:00
Jay Wren
0f9b33aa36 reduce copying 2025-11-11 14:51:53 -05:00
Jay Wren
ef0a022375 more nonblocking 2025-11-11 14:22:40 -05:00
Jay Wren
b68e504865 hrm 2025-11-11 13:15:30 -05:00
Jay Wren
3344a840d1 just using the wg library works 2025-11-11 10:55:39 -05:00
Jay Wren
2bc9863e66 only wg tun, no batching 2025-11-10 16:54:00 -05:00
Wade Simmons
97b3972c11 honor remote_allow_list in hole punch response (#1186)
* honor remote_allow_ilst in hole punch response

When we receive a "hole punch notification" from a Lighthouse, we send
a hole punch packet to every remote of that host, even if we don't
include those remotes in our "remote_allow_list". Change the logic here
to check if the remote IP is in our allow list before sending the hole
punch packet.

* fix for netip

* cleanup
2025-11-10 13:52:40 -05:00
Jack Doan
0f305d5397 don't block startup on failure to configure SSH (#1520) 2025-11-05 10:41:56 -06:00
Jack Doan
01909f4715 try to make certificate addition/removal reloadable in some cases (#1468)
* try to make certificate addition/removal reloadable in some cases

* very spicy change to respond to handshakes with cert versions we cannot match with a cert that we can indeed match

* even spicier change to rehandshake if we detect our cert is lower-version than our peer, and we have a newer-version cert available

* make tryRehandshake easier to understand
2025-11-03 19:38:44 -06:00
Jack Doan
770147264d fix make bench (#1510) 2025-10-21 11:32:34 -05:00
dependabot[bot]
fa8c013b97 Bump github.com/miekg/dns from 1.1.65 to 1.1.68 (#1444)
Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.65 to 1.1.68.
- [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release)
- [Commits](https://github.com/miekg/dns/compare/v1.1.65...v1.1.68)

---
updated-dependencies:
- dependency-name: github.com/miekg/dns
  dependency-version: 1.1.68
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 16:41:51 -04:00
dependabot[bot]
2710f2af06 Bump github.com/kardianos/service from 1.2.2 to 1.2.4 (#1433)
Bumps [github.com/kardianos/service](https://github.com/kardianos/service) from 1.2.2 to 1.2.4.
- [Commits](https://github.com/kardianos/service/compare/v1.2.2...v1.2.4)

---
updated-dependencies:
- dependency-name: github.com/kardianos/service
  dependency-version: 1.2.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:58:15 -04:00
dependabot[bot]
ad6d3e6bac Bump the golang-x-dependencies group across 1 directory with 5 updates (#1409)
Bumps the golang-x-dependencies group with 3 updates in the / directory: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net) and [golang.org/x/sync](https://github.com/golang/sync).


Updates `golang.org/x/crypto` from 0.37.0 to 0.38.0
- [Commits](https://github.com/golang/crypto/compare/v0.37.0...v0.38.0)

Updates `golang.org/x/net` from 0.39.0 to 0.40.0
- [Commits](https://github.com/golang/net/compare/v0.39.0...v0.40.0)

Updates `golang.org/x/sync` from 0.13.0 to 0.14.0
- [Commits](https://github.com/golang/sync/compare/v0.13.0...v0.14.0)

Updates `golang.org/x/sys` from 0.32.0 to 0.33.0
- [Commits](https://github.com/golang/sys/compare/v0.32.0...v0.33.0)

Updates `golang.org/x/term` from 0.31.0 to 0.32.0
- [Commits](https://github.com/golang/term/compare/v0.31.0...v0.32.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.38.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-version: 0.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sync
  dependency-version: 0.14.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-version: 0.33.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-version: 0.32.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:54:38 -04:00
dependabot[bot]
2b0aa74e85 Bump github.com/prometheus/client_golang from 1.22.0 to 1.23.2 (#1470)
Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.22.0 to 1.23.2.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.22.0...v1.23.2)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-version: 1.23.2
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:16:24 -04:00
dependabot[bot]
b126d88963 Bump github.com/gaissmai/bart from 0.20.4 to 0.25.0 (#1471)
Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.20.4 to 0.25.0.
- [Release notes](https://github.com/gaissmai/bart/releases)
- [Commits](https://github.com/gaissmai/bart/compare/v0.20.4...v0.25.0)

---
updated-dependencies:
- dependency-name: github.com/gaissmai/bart
  dependency-version: 0.25.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:15:07 -04:00
Nate Brown
45c1d3eab3 Support for multi proto tun device on OpenBSD (#1495) 2025-10-08 16:56:42 -05:00
Gary Guo
634181ba66 Fix incorrect CIDR construction in hostmap (#1493)
* Fix incorrect CIDR construction in hostmap

* Introduce a regression test for incorrect hostmap CIDR
2025-10-08 11:02:36 -05:00
Nate Brown
eb89839d13 Support for multi proto tun device on NetBSD (#1492) 2025-10-07 20:17:50 -05:00
Nate Brown
fb7f0c3657 Use x/net/route to manage routes directly (#1488) 2025-10-03 10:59:53 -05:00
sl274
b1f53d8d25 Support IPv6 tunneling in FreeBSD (#1399)
Recent merge of cert-v2 support introduced the ability to tunnel IPv6. However, FreeBSD's IPv6 tunneling does not work for 2 reasons:
* The ifconfig commands did not work for IPv6 addresses
* The tunnel device was not configured for link-layer mode, so it only supported IPv4

This PR improves FreeBSD tunneling support in 3 ways:
* Use ioctl instead of exec'ing ifconfig to configure the interface, with additional logic to support IPv6
* Configure the tunnel in link-layer mode, allowing IPv6 traffic
* Use readv() and writev() to communicate with the tunnel device, to avoid the need to copy the packet buffer
2025-10-02 21:54:30 -05:00
Jack Doan
8824eeaea2 helper functions to more correctly marshal curve 25519 public keys (#1481) 2025-10-02 13:56:41 -05:00
dependabot[bot]
071589f7c7 Bump actions/setup-go from 5 to 6 (#1469)
* Bump actions/setup-go from 5 to 6

Bumps [actions/setup-go](https://github.com/actions/setup-go) from 5 to 6.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Hardcode the last one to go v1.25

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-10-02 00:05:12 -05:00
Jack Doan
f1e992f6dd don't require a detailsVpnAddr in a HostUpdateNotification (#1472)
* don't require a detailsVpnAddr in a HostUpdateNotification

* don't send our own addr on HostUpdateNotification for v2
2025-09-29 13:43:12 -05:00
Jack Doan
1ea5f776d7 update to go 1.25, use the cool new ECDSA key marshalling functions (#1483)
* update to go 1.25, use the cool new ECDSA key marshalling functions

* bonk the runners

* actually bump go.mod

* bump golangci-lint
2025-09-29 13:02:25 -05:00
Henry Graham
4cdeb284ef Set CKA_VALUE_LEN attribute in DeriveNoise (#1482) 2025-09-25 13:24:52 -05:00
Jack Doan
5cccd39465 update RemoteList.vpnAddrs when we complete a handshake (#1467) 2025-09-10 09:44:25 -05:00
Jack Doan
8196c22b5a store lighthouses as a slice (#1473)
* store lighthouses as a slice. If you have fewer than 16 lighthouses (and fewer than 16 vpnaddrs on a host, I guess), it's faster
2025-09-10 09:43:25 -05:00
Jack Doan
65cc253c19 prevent linux from assigning ipv6 link-local addresses (#1476) 2025-09-09 13:25:23 -05:00
Wade Simmons
73cfa7b5b1 add firewall tests for ipv6 (#1451)
Test things like cidr and local_cidr with ipv6 addresses, to ensure
everything is working correctly.
2025-09-08 13:57:36 -04:00
Jack Doan
768325c9b4 cert-v2 chores (#1466) 2025-09-05 15:08:22 -05:00
Jack Doan
932e329164 Don't delete static host mappings for non-primary IPs (#1464)
* Don't delete a vpnaddr if it's part of a certificate that contains a vpnaddr that's in the static host map

* remove unused arg from ConnectionManager.shouldSwapPrimary()
2025-09-04 14:49:40 -05:00
Jack Doan
4bea299265 don't send recv errors for packets outside the connection window anymore (#1463)
* don't send recv errors for packets outside the connection window anymore

* Pull in fix from #1459, add my opinion on maxRecvError

* remove recv_error counter entirely
2025-09-03 11:52:52 -05:00
Wade Simmons
5cff83b282 netlink: ignore route updates with no destination (#1437)
Currently we assume each route update must have a destination, but we
should check that it is set before we try to use it.

See: #1436
2025-08-25 13:05:35 -05:00
Wade Simmons
7da79685ff fix lighthouse.calculated_remotes parsing (#1438)
Some checks failed
gofmt / Run gofmt (push) Successful in 27s
smoke-extra / Run extra smoke tests (push) Failing after 21s
smoke / Run multi node smoke test (push) Failing after 1m21s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m9s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m16s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m41s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
This was broken with the change to yaml.v3:

- https://github.com/slackhq/nebula/pull/1148

We forgot to update these references to `map[string]any`.

Without this fix, Nebula crashes with an error like this:

    {"error":"config `lighthouse.calculated_remotes` has invalid type: map[string]interface {}","level":"error","msg":"Invalid lighthouse.calculated_remotes","time":"2025-07-29T15:50:06.479499Z"}
2025-07-29 13:12:07 -04:00
56 changed files with 3617 additions and 673 deletions

View File

@@ -16,9 +16,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Install goimports

View File

@@ -12,9 +12,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build
@@ -35,9 +35,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build
@@ -68,9 +68,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Import certificates

View File

@@ -22,9 +22,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version-file: 'go.mod'
go-version: '1.25'
check-latest: true
- name: add hashicorp source

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: build

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build
@@ -34,7 +34,7 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
with:
version: v2.1
version: v2.5
- name: Test
run: make test
@@ -58,9 +58,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build
@@ -79,9 +79,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.22'
go-version: '1.25'
check-latest: true
- name: Build
@@ -100,9 +100,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build nebula
@@ -117,7 +117,7 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
with:
version: v2.1
version: v2.5
- name: Test
run: make test

View File

@@ -84,16 +84,11 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
rawMap, ok := value.(map[any]any)
rawMap, ok := value.(map[string]any)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
for rawCIDR, rawValue := range rawMap {
cidr, err := netip.ParsePrefix(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
@@ -129,7 +124,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
}
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[any]any)
rawMap, ok := raw.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid type: %T", raw)
}

View File

@@ -58,6 +58,9 @@ type Certificate interface {
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
PublicKey() []byte
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
MarshalPublicKeyPEM() []byte
// Curve identifies which curve was used for the PublicKey and Signature.
Curve() Curve
@@ -135,8 +138,7 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
case Version2:
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
default:
//TODO: CERT-V2 make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
return nil, ErrUnknownVersion
}
if err != nil {

View File

@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey
}
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV1) Signature() []byte {
return c.signature
}
@@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -1,6 +1,7 @@
package cert
import (
"crypto/ed25519"
"fmt"
"net/netip"
"testing"
@@ -13,6 +14,7 @@ import (
)
func TestCertificateV1_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{

View File

@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
@@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -15,6 +15,7 @@ import (
)
func TestCertificateV2_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
publicKey: pubKey,
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.curve = Curve_P256
nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{
details: detailsV2{

View File

@@ -20,6 +20,7 @@ var (
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")

View File

@@ -7,19 +7,26 @@ import (
"golang.org/x/crypto/ed25519"
)
const (
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
)
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
const ( //key-agreement-key banners
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
)
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
const ( //signing key banners
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
)
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -51,6 +58,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
} else {
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
}
}
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
@@ -62,6 +79,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
}
}
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
@@ -73,7 +103,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32
curve = Curve_CURVE25519
case P256PublicKeyBanner:
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
// Uncompressed
expectedLen = 65
curve = Curve_P256

View File

@@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalX25519PublicKey(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
oldPubP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)

View File

@@ -7,7 +7,6 @@ import (
"crypto/rand"
"crypto/sha256"
"fmt"
"math/big"
"net/netip"
"time"
)
@@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
if err != nil {
return nil, err
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1

View File

@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
}
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
nc := &cert.TBSCertificate{
Version: v,
Curve: c.Curve(),
Name: c.Name(),
Networks: c.Networks(),
UnsafeNetworks: c.UnsafeNetworks(),
Groups: c.Groups(),
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
PublicKey: c.PublicKey(),
IsCA: false,
}
c, err := nc.Sign(ca, ca.Curve(), key)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pem
}
func X25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {

View File

@@ -354,9 +354,8 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if mainHostInfo {
decision = tryRehandshake
} else {
if cm.shouldSwapPrimary(hostinfo, primary) {
if cm.shouldSwapPrimary(hostinfo) {
decision = swapPrimary
} else {
// migrate the relays to the primary, if in use.
@@ -447,7 +446,7 @@ func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time
return inactiveDuration, true
}
func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
@@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool
}
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
if crt == nil {
//my cert was reloaded away. We should definitely swap from this tunnel
return true
}
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
@@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
cm.hostMap.Unlock()
}
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
// check and return true.
// isInvalidCertificate decides if we should destroy a tunnel.
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
remoteCert := hostinfo.GetCert()
if remoteCert == nil {
return false
return false //don't tear down tunnels for handshakes in progress
}
caPool := cm.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil {
return false
}
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
return false //cert is still valid! yay!
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
// Block listed certificates should always be disconnected
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is blocked, tearing down the tunnel")
return true
} else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
} else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
return false
}
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
}
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
@@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
peerCrt := hostinfo.ConnectionState.peerCert
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version lower than peer, attempting to correct").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
return
}
}
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "current cert version < pki.initiatingVersion").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
}

View File

@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan netip.Addr, 10),
}
lighthouses := map[netip.Addr]struct{}{}
lighthouses := []netip.Addr{}
staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses)
@@ -446,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey
}
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte {
return d.signature
}

View File

@@ -129,6 +129,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
return control, vpnNetworks, udpAddr, c
}
// newServer creates a nebula instance with fewer assumptions
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnNetworks := certs[len(certs)-1].Networks()
var udpAddr netip.AddrPort
if vpnNetworks[0].Addr().Is4() {
budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnNetworks[0].Addr().As16()
// beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
caStr := ""
for _, ca := range caCrt {
x, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
caStr += string(x)
}
certStr := ""
for _, c := range certs {
x, err := c.MarshalPEM()
if err != nil {
panic(err)
}
certStr += string(x)
}
mc := m{
"pki": m{
"ca": caStr,
"cert": certStr,
"key": string(key),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{
"host": udpAddr.Addr().String(),
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
},
"timers": m{
"pending_deletion_interval": 2,
"connection_alive_interval": 2,
},
}
if overrides != nil {
final := m{}
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
c := config.NewC(l)
cStr := string(cb)
c.LoadString(cStr)
control, err := nebula.Main(c, false, "e2e-test", l, nil)
if err != nil {
panic(err)
}
return control, vpnNetworks, udpAddr, c
}
type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb {

View File

@@ -4,12 +4,16 @@
package e2e
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestDropInactiveTunnels(t *testing.T) {
@@ -55,3 +59,262 @@ func TestDropInactiveTunnels(t *testing.T) {
myControl.Stop()
theirControl.Stop()
}
func TestCertUpgrade(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCert2Pem),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": myC.Settings["firewall"],
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v2-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if c == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
r.Logf("version %d", version)
if version == cert.Version2 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*10 {
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertDowngrade(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCertPem),
"key": string(myPrivKey),
},
"firewall": myC.Settings["firewall"],
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v1-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == cert.Version1 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertMismatchCorrection(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == theirVersion {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("wtf")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -68,6 +68,9 @@ func TestFirewall_AddRule(t *testing.T) {
ti, err := netip.ParsePrefix("1.2.3.4/32")
require.NoError(t, err)
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -92,12 +95,24 @@ func TestFirewall_AddRule(t *testing.T) {
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -117,6 +132,13 @@ func TestFirewall_AddRule(t *testing.T) {
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -199,6 +221,82 @@ func TestFirewall_Drop(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
c := dummyCert{
name: "host1",
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
}
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
f := &Firewall{}
ft := FirewallTable{
@@ -208,6 +306,10 @@ func BenchmarkFirewallTable_match(b *testing.B) {
pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@@ -239,6 +341,15 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
ip := netip.MustParsePrefix("fd99::99/128")
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -252,6 +363,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -265,6 +388,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -289,6 +424,17 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
},
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on name", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -447,6 +593,42 @@ func TestFirewall_Drop3(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
network := netip.MustParsePrefix("fd12::34/120")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{network},
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
@@ -510,6 +692,50 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
},
}
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
},
}
h1 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c1,
},
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
}
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.0.2.1"),
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
ml := func(m map[string]struct{}, a [][]string) {
for n := 0; n < b.N; n++ {
@@ -727,6 +953,21 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}

37
go.mod
View File

@@ -1,8 +1,6 @@
module github.com/slackhq/nebula
go 1.23.0
toolchain go1.24.1
go 1.25
require (
dario.cat/mergo v1.0.2
@@ -10,30 +8,30 @@ require (
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.20.4
github.com/gaissmai/bart v0.25.0
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.65
github.com/kardianos/service v1.2.4
github.com/miekg/dns v1.1.68
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.22.0
github.com/prometheus/client_golang v1.23.2
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.10.0
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.37.0
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.39.0
golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0
golang.org/x/term v0.31.0
golang.org/x/net v0.45.0
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
golang.org/x/term v0.36.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.6
google.golang.org/protobuf v1.36.8
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)
@@ -45,11 +43,12 @@ require (
github.com/google/btree v1.1.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/mod v0.23.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.30.0 // indirect
golang.org/x/tools v0.33.0 // indirect
)

69
go.sum
View File

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -64,8 +64,8 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -106,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -143,29 +143,33 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -176,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -185,8 +189,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -197,18 +201,17 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -219,8 +222,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -239,8 +242,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -23,13 +23,17 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return false
}
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
v := cs.initiatingVersion
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
} else if v < cert.Version2 {
// If we're connecting to a v6 address we should encourage use of a V2 cert
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
}
@@ -48,6 +52,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return false
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -103,6 +108,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available")
return
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -143,8 +149,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp, fperr := rc.Fingerprint()
if fperr != nil {
fp = "<error generating certificate fingerprint>"
}
@@ -163,16 +169,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
rc := cs.getCertificate(remoteCert.Certificate.Version())
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Unable to handshake with host due to missing certificate version")
return
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{
"udpAddr": addr,
"handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
}
} else {
// Record the certificate we are actually using
ci.myCert = myCertOtherVersion
}
// Record the certificate we are actually using
ci.myCert = rc
}
if len(remoteCert.Certificate.Networks()) == 0 {
@@ -459,7 +468,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.remotes.ResetBlockedRemotes()
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
return
}
@@ -667,7 +676,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}
hostinfo.remotes.ResetBlockedRemotes()
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
f.metricHandshakes.Update(duration)
return false

View File

@@ -68,11 +68,12 @@ type HandshakeManager struct {
type HandshakeHostInfo struct {
sync.Mutex
startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready
counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo
}

View File

@@ -17,12 +17,10 @@ import (
"github.com/slackhq/nebula/header"
)
// const ProbeLen = 100
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
const MaxRemotes = 10
const maxRecvError = 4
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
// 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -225,8 +223,7 @@ type HostInfo struct {
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
// The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
recvError atomic.Uint32
vpnAddrs []netip.Addr
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Lite
@@ -733,13 +730,6 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}
func (i *HostInfo) RecvErrorExceeded() bool {
if i.recvError.Add(1) >= maxRecvError {
return true
}
return true
}
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed
@@ -748,7 +738,8 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
i.networks = new(bart.Lite)
for _, network := range networks {
i.networks.Insert(network)
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
i.networks.Insert(nprefix)
}
for _, network := range unsafeNetworks {

143
inside.go
View File

@@ -11,6 +11,149 @@ import (
"github.com/slackhq/nebula/routing"
)
// consumeInsidePackets processes multiple packets in a batch for improved performance
// packets: slice of packet buffers to process
// sizes: slice of packet sizes
// count: number of packets to process
// outs: slice of output buffers (one per packet) with virtio headroom
// q: queue index
// localCache: firewall conntrack cache
// batchPackets: pre-allocated slice for accumulating encrypted packets
// batchAddrs: pre-allocated slice for accumulating destination addresses
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
// Reusable per-packet state
fwPacket := &firewall.Packet{}
// Reset batch accumulation slices (reuse capacity)
*batchPackets = (*batchPackets)[:0]
*batchAddrs = (*batchAddrs)[:0]
// Process each packet in the batch
for i := 0; i < count; i++ {
packet := packets[i][:sizes[i]]
out := outs[i]
// Inline the consumeInsidePacket logic for better performance
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
}
continue
}
// Ignore local broadcast packets
if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
continue
}
}
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
// Immediately forward packets from self to self.
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
}
}
continue
}
// Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
continue
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
}
continue
}
if !ready {
continue
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
}
continue
}
// Encrypt and prepare packet for batch sending
ci := hostinfo.ConnectionState
if ci.eKey == nil {
continue
}
// Check if this needs relay - if so, send immediately and skip batching
useRelay := !hostinfo.remote.IsValid()
if useRelay {
// Handle relay sends individually (less common path)
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
continue
}
// Encrypt the packet for batch sending
if noiseutil.EncryptLockNeeded {
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query lighthouse if needed
if hostinfo.lastRebindCount != f.rebindCount {
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("counter", c).
Error("Failed to encrypt outgoing packet")
continue
}
// Add to batch
*batchPackets = append(*batchPackets, out)
*batchAddrs = append(*batchAddrs, hostinfo.remote)
}
// Send all accumulated packets in one batch
if len(*batchPackets) > 0 {
batchSize := len(*batchPackets)
f.batchMetrics.udpWriteSize.Update(int64(batchSize))
n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
}
}
}
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"runtime"
@@ -22,6 +21,7 @@ import (
)
const mtu = 9001
const virtioNetHdrLen = overlay.VirtioNetHdrLen
type InterfaceConfig struct {
HostMap *HostMap
@@ -50,6 +50,13 @@ type InterfaceConfig struct {
l *logrus.Logger
}
type batchMetrics struct {
udpReadSize metrics.Histogram
tunReadSize metrics.Histogram
udpWriteSize metrics.Histogram
tunWriteSize metrics.Histogram
}
type Interface struct {
hostMap *HostMap
outside udp.Conn
@@ -86,11 +93,12 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []udp.Conn
readers []io.ReadWriteCloser
readers []overlay.BatchReadWriter
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
batchMetrics *batchMetrics
l *logrus.Logger
}
@@ -177,7 +185,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
readers: make([]overlay.BatchReadWriter, c.routines),
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
@@ -193,6 +201,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
},
batchMetrics: &batchMetrics{
udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)),
tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)),
udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)),
tunWriteSize: metrics.GetOrRegisterHistogram("batch.tun_write_size", nil, metrics.NewUniformSample(1024)),
},
l: c.l,
}
@@ -225,7 +239,7 @@ func (f *Interface) activate() {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
var reader io.ReadWriteCloser = f.inside
var reader overlay.BatchReadWriter = f.inside
for i := 0; i < f.routines; i++ {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
@@ -266,39 +280,69 @@ func (f *Interface) listenOut(i int) {
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
// Pre-allocate output buffers for batch processing
batchSize := li.BatchSize()
outs := make([][]byte, batchSize)
for idx := range outs {
// Allocate full buffer with virtio header space
outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
}
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
})
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
batchSize := reader.BatchSize()
// Allocate buffers for batch reading
bufs := make([][]byte, batchSize)
for idx := range bufs {
bufs[idx] = make([]byte, mtu)
}
sizes := make([]int, batchSize)
// Allocate output buffers for batch processing (one per packet)
// Each has virtio header headroom to avoid copies on write
outs := make([][]byte, batchSize)
for idx := range outs {
outBuf := make([]byte, virtioNetHdrLen+mtu)
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
}
// Pre-allocate batch accumulation buffers for sending
batchPackets := make([][]byte, 0, batchSize)
batchAddrs := make([]netip.AddrPort, 0, batchSize)
// Pre-allocate nonce buffer (reused for all encryptions)
nb := make([]byte, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
n, err := reader.BatchRead(bufs, sizes)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
}
f.l.WithError(err).Error("Error while reading outbound packet")
f.l.WithError(err).Error("Error while batch reading outbound packets")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
f.batchMetrics.tunReadSize.Update(int64(n))
// Process all packets in the batch at once
f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs)
}
}

View File

@@ -24,6 +24,7 @@ import (
)
var ErrHostNotKnown = errors.New("host not known")
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
@@ -56,7 +57,7 @@ type LightHouse struct {
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[[]netip.Addr]
interval atomic.Int64
updateCancel context.CancelFunc
@@ -107,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
lighthouses := make(map[netip.Addr]struct{})
lighthouses := make([]netip.Addr, 0)
h.lighthouses.Store(&lighthouses)
staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList)
@@ -143,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load()
}
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
func (lh *LightHouse) GetLighthouses() []netip.Addr {
return *lh.lighthouses.Load()
}
@@ -306,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
if initial || c.HasChanged("lighthouse.hosts") {
lhMap := make(map[netip.Addr]struct{})
err := lh.parseLighthouses(c, lhMap)
lhList, err := lh.parseLighthouses(c)
if err != nil {
return err
}
lh.lighthouses.Store(&lhMap)
lh.lighthouses.Store(&lhList)
if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed")
@@ -346,36 +346,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil
}
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
out := make([]netip.Addr, len(lhs))
for i, host := range lhs {
addr, err := netip.ParseAddr(host)
if err != nil {
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
}
if !lh.myVpnNetworksTable.Contains(addr) {
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
}
lhMap[addr] = struct{}{}
out[i] = addr
}
if !lh.amLighthouse && len(lhMap) == 0 {
if !lh.amLighthouse && len(out) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
}
staticList := lh.GetStaticHostList()
for lhAddr, _ := range lhMap {
if _, ok := staticList[lhAddr]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
for i := range out {
if _, ok := staticList[out[i]]; !ok {
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
}
}
return nil
return out, nil
}
func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -486,7 +487,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
return lh.unlockedGetRemoteList(vpnAddrs)
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
}
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -519,11 +520,15 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
}
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
// First we check the static mapping
// and do nothing if it is there
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
return
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
staticList := lh.GetStaticHostList()
for _, addr := range allVpnAddrs {
if _, ok := staticList[addr]; ok {
return
}
}
// None of the VpnAddrs were present. Now we can do the deletes.
lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok {
@@ -565,7 +570,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetAddrs() {
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
continue
}
switch {
@@ -627,23 +632,30 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
return len(calculatedV4) > 0 || len(calculatedV6) > 0
}
// unlockedGetRemoteList
// assumes you have the lh lock
// unlockedGetRemoteList assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
am, ok := lh.addrMap[allAddrs[0]]
if !ok {
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
for _, addr := range allAddrs {
lh.addrMap[addr] = am
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
for i, addr := range allAddrs {
am, ok := lh.addrMap[addr]
if ok {
if i != 0 {
lh.addrMap[allAddrs[0]] = am
}
return am
}
}
am := NewRemoteList(allAddrs, lh.shouldAdd)
for _, addr := range allAddrs {
lh.addrMap[addr] = am
}
return am
}
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
}
if !allow {
@@ -698,19 +710,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
}
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
return true
l := lh.GetLighthouses()
for i := range l {
if l[i] == vpnAddr {
return true
}
}
return false
}
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
l := lh.GetLighthouses()
for _, a := range vpnAddr {
if _, ok := l[a]; ok {
return true
for i := range vpnAddrs {
for j := range l {
if l[j] == vpnAddrs[i] {
return true
}
}
}
return false
@@ -752,7 +767,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0
lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses {
for _, lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
@@ -870,7 +885,7 @@ func (lh *LightHouse) SendUpdate() {
updated := 0
lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses {
for _, lhVpnAddr := range lighthouses {
var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
@@ -928,7 +943,6 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4,
V6AddrPorts: v6,
RelayVpnAddrs: relays,
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
},
}
@@ -1048,19 +1062,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return
}
useVersion := cert.Version1
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
useVersion = 1
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = 2
} else {
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
}
return
}
@@ -1069,9 +1083,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 {
if !queryVpnAddr.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else {
@@ -1116,8 +1127,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok {
whereToPunch = newDest
} else {
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
//choosing to do nothing for now, but maybe we return an error?
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
}
}
}
@@ -1176,19 +1188,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if !r.Is4() {
continue
}
b = r.As4()
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
}
} else if v == cert.Version2 {
for _, r := range c.relay.relay {
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
}
} else {
//TODO: CERT-V2 don't panic
panic("unsupported version")
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("version", v).Debug("unsupported protocol version")
}
}
}
}
@@ -1198,18 +1208,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
return
}
lhh.lh.Lock()
var certVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
certVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
}
return
}
relays := n.Details.GetRelays()
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
am.Lock()
lhh.lh.Unlock()
@@ -1234,27 +1242,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return
}
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
var detailsVpnAddr netip.Addr
useVersion := cert.Version1
if n.Details.OldVpnAddr != 0 {
var useVersion cert.Version
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1
} else if n.Details.VpnAddr != nil {
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
}
return
detailsVpnAddr = netip.Addr{}
useVersion = cert.Version2
}
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
//Simple check that the host sent this not someone else
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
}
@@ -1268,24 +1273,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock()
n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck
if useVersion == cert.Version1 {
switch useVersion {
case cert.Version1:
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return
}
vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
} else {
case cert.Version2:
// do nothing, we want to send a blank message
default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return
}
@@ -1303,13 +1308,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return
}
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
}
return
}
empty := []byte{0}
punch := func(vpnPeer netip.AddrPort) {
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
if !vpnPeer.IsValid() {
return
}
@@ -1321,48 +1333,38 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}()
if lhh.l.Level >= logrus.DebugLevel {
var logVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
logVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
}
}
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts {
punch(protoV4AddrPortToNetAddrPort(a))
b := protoV4AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
}
for _, a := range n.Details.V6AddrPorts {
punch(protoV6AddrPortToNetAddrPort(a))
b := protoV6AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
}
@@ -1441,3 +1443,17 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
}
return netip.Addr{}, false
}
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
if d.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
detailsVpnAddr := netip.AddrFrom4(b)
return detailsVpnAddr, cert.Version1, nil
} else if d.VpnAddr != nil {
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
return detailsVpnAddr, cert.Version2, nil
} else {
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
}
}

View File

@@ -493,3 +493,123 @@ func Test_findNetworkUnion(t *testing.T) {
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
testStaticHost := netip.MustParseAddr("10.128.0.42")
//myVpnIp := netip.MustParseAddr("10.128.0.2")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
"10.128.0.42": []any{"1.2.3.4:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//test that we actually have the static entry:
out := lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//bolt on a lower numbered primary IP
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
lh.addrMap[testSameHostNotStatic] = am
out.Rebuild([]netip.Prefix{}) //???
//test that we actually have the static entry:
out = lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
//test that we actually have the static entry for BOTH:
out2 := lh.Query(testSameHostNotStatic)
assert.Same(t, out2, out)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
//verify
out = lh.Query(testSameHostNotStatic)
assert.NotNil(t, out)
if out == nil {
t.Fatal("expected non-nil query for the static host")
}
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
}
func TestLighthouse_DeletesWork(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testHost := netip.MustParseAddr("10.128.0.42")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//insert the host
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
am.vpnAddrs = []netip.Addr{testHost}
am.addrs = []netip.AddrPort{myUdpAddr2}
lh.addrMap[testHost] = am
am.Rebuild([]netip.Prefix{}) //???
//test that we actually have the entry:
out := lh.Query(testHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testHost})
//verify
out = lh.Query(testHost)
assert.Nil(t, out)
}

View File

@@ -75,7 +75,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
sshStart = nil
}
}
@@ -164,7 +165,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}

View File

@@ -95,8 +95,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -138,7 +137,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f)
// Fallthrough to the bottom to record incoming traffic
@@ -160,7 +159,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out)
}
// Fallthrough to the bottom to record incoming traffic
@@ -203,7 +202,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -254,16 +253,18 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
return false
} else {
return false
}
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
}
return true
@@ -472,9 +473,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
err = newPacket(out, true, fwPacket)
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
Warnf("Error while validating inbound packet")
return false
}
@@ -489,7 +492,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
@@ -537,10 +540,6 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
return
}
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
@@ -550,3 +549,108 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
// We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo)
}
// readOutsidePacketsBatch processes multiple packets received from UDP in a batch
// and writes all successfully decrypted packets to TUN in a single operation
func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) {
// Pre-allocate slice for accumulating successful decryptions
tunPackets := make([][]byte, 0, count)
for i := 0; i < count; i++ {
payload := payloads[i]
addr := addrs[i]
out := outs[i]
// Parse header
err := h.Parse(payload)
if err != nil {
if len(payload) > 1 {
f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err)
}
continue
}
if addr.IsValid() {
if f.myVpnNetworksTable.Contains(addr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
}
continue
}
}
var hostinfo *HostInfo
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, addr, h) {
continue
}
switch h.Subtype {
case header.MessageNone:
// Decrypt packet
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
continue
}
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet")
continue
}
if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet")
continue
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet")
}
continue
}
f.connectionManager.In(hostinfo)
// Add to batch for TUN write
tunPackets = append(tunPackets, out)
case header.MessageRelay:
// Skip relay packets in batch mode for now (less common path)
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
default:
hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype)
}
default:
// Handle non-Message types using single-packet path
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
}
}
if len(tunPackets) > 0 {
n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun")
}
f.batchMetrics.tunWriteSize.Update(int64(len(tunPackets)))
}
}

View File

@@ -7,11 +7,25 @@ import (
"github.com/slackhq/nebula/routing"
)
type Device interface {
// BatchReadWriter extends io.ReadWriteCloser with batch I/O operations
type BatchReadWriter interface {
io.ReadWriteCloser
// BatchRead reads multiple packets at once
BatchRead(bufs [][]byte, sizes []int) (int, error)
// WriteBatch writes multiple packets at once
WriteBatch(bufs [][]byte, offset int) (int, error)
// BatchSize returns the optimal batch size for this device
BatchSize() int
}
type Device interface {
BatchReadWriter
Activate() error
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (io.ReadWriteCloser, error)
NewMultiQueueReader() (BatchReadWriter, error)
}

View File

@@ -1,6 +1,8 @@
package overlay
import (
"fmt"
"net"
"net/netip"
"github.com/sirupsen/logrus"
@@ -9,6 +11,7 @@ import (
)
const DefaultMTU = 1300
const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
@@ -70,3 +73,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
for _, gateway := range gateways {
if dest.Addr().Is4() && gateway.Addr().Is4() {
return gateway, nil
}
if dest.Addr().Is6() && gateway.Addr().Is6() {
return gateway, nil
}
}
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
}

View File

@@ -95,6 +95,29 @@ func (t *tun) Name() string {
return "android"
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
"sync/atomic"
@@ -295,7 +294,6 @@ func (t *tun) activate6(network netip.Prefix) error {
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD,
}
@@ -551,16 +549,32 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
// BatchRead reads a single packet (batch size 1 for non-Linux platforms)
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for non-Linux platforms)
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for non-Linux platforms (no batching)
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -105,10 +105,36 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) {
return t, nil
}
// BatchRead reads a single packet (batch size 1 for disabled tun)
func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for disabled tun)
func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for disabled tun (no batching)
func (t *disabledTun) BatchSize() int {
return 1
}
func (t *disabledTun) Close() error {
if t.read != nil {
close(t.read)

View File

@@ -10,11 +10,9 @@ import (
"io"
"io/fs"
"net/netip"
"os"
"os/exec"
"strconv"
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/gaissmai/bart"
@@ -22,12 +20,18 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678
FIODGNAME = 0x80106678
TUNSIFMODE = 0x8004745e
TUNSIFHEAD = 0x80047460
OSIOCAIFADDR_IN6 = 0x8088691b
IN6_IFF_NODAD = 0x0020
)
type fiodgnameArg struct {
@@ -37,43 +41,159 @@ type fiodgnameArg struct {
}
type ifreqRename struct {
Name [16]byte
Name [unix.IFNAMSIZ]byte
Data uintptr
}
type ifreqDestroy struct {
Name [16]byte
Name [unix.IFNAMSIZ]byte
pad [16]byte
}
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
}
type ifreqMTU struct {
Name [unix.IFNAMSIZ]byte
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
}
io.ReadWriteCloser
func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if t.devFd >= 0 {
err := syscall.Close(t.devFd)
if err != nil {
return err
t.l.WithError(err).Error("Error closing device")
}
defer syscall.Close(s)
t.devFd = -1
ifreq := ifreqDestroy{Name: t.deviceBytes()}
c := make(chan struct{})
go func() {
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
}
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
}
}()
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
// wait up to 1 second so we start blocking at the ioctl
select {
case <-c:
case <-time.After(1 * time.Second):
}
}
return nil
@@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var fd int
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
}
if err != nil {
return nil, err
}
rawConn, err := file.SyscallConn()
if err != nil {
return nil, fmt.Errorf("SyscallConn: %v", err)
// Read the name of the interface
var name [16]byte
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
}
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil {
return nil, err
}
@@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
// If the name doesn't match the desired interface name, rename it now
if ifName != deviceName {
s, err := syscall.Socket(
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return nil, err
}
@@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
devFd: fd,
}
err = t.reload(c, true)
@@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
if cidr.Addr().Is4() {
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
if cidr.Addr().Is6() {
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
return t.addRoutes(false)
}
func (t *tun) setMTU() error {
// Set the MTU on the device
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -256,10 +450,36 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
// BatchRead reads a single packet (batch size 1 for FreeBSD)
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for FreeBSD)
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for FreeBSD (no batching)
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
@@ -268,15 +488,16 @@ func (t *tun) addRoutes(logErrors bool) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -289,9 +510,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -306,3 +526,120 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}

View File

@@ -151,6 +151,29 @@ func (t *tun) Name() string {
return "iOS"
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -9,7 +9,6 @@ import (
"net"
"net/netip"
"os"
"strings"
"sync/atomic"
"time"
"unsafe"
@@ -21,10 +20,12 @@ import (
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
)
type tun struct {
io.ReadWriteCloser
wgDevice wgtun.Device
fd int
Device string
vpnNetworks []netip.Prefix
@@ -65,59 +66,154 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
// This allows multiqueue readers to use the same wireguard Device batching as the main device
type wgDeviceWrapper struct {
dev wgtun.Device
buf []byte // Reusable buffer for single packet reads
}
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
// Use wireguard Device's batch API for single packet
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := w.dev.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := w.dev.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
// Pass all buffers to WireGuard's batch write
return w.dev.Write(bufs, offset)
}
func (w *wgDeviceWrapper) Close() error {
return w.dev.Close()
}
// BatchRead implements batching for multiqueue readers
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
// The zero here is offset.
return w.dev.Read(bufs, sizes, 0)
}
// BatchSize returns the optimal batch size
func (w *wgDeviceWrapper) BatchSize() int {
return w.dev.BatchSize()
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
if err != nil {
return nil, fmt.Errorf("failed to create TUN from FD: %w", err)
}
file := wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
_ = wgDev.Close()
return nil, err
}
t.Device = "tun0"
t.wgDevice = wgDev
t.Device = name
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
if os.IsNotExist(err) {
err = os.MkdirAll("/dev/net", 0755)
if err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
if err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
}
} else {
return nil, err
// Check if /dev/net/tun exists, create if needed (for docker containers)
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
if err := os.MkdirAll("/dev/net", 0755); err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
name := strings.Trim(string(req.Name[:]), "\x00")
devName := c.GetString("tun.dev", "")
mtu := c.GetInt("tun.mtu", DefaultMTU)
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
// Create TUN device manually to support multiqueue
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], devName)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
// Set nonblocking
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Enable TCP and UDP offload (TSO/GRO) for performance
// This allows the kernel to handle segmentation/coalescing
const (
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
offloads := tunTCPOffloads | tunUDPOffloads
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
// Log warning but don't fail - offload is optional
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
// Create wireguard device from file descriptor
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create TUN from file: %w", err)
}
name, err := wgDev.Name()
if err != nil {
_ = wgDev.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
}
// file is now owned by wgDev, get a new reference
file = wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
_ = wgDev.Close()
return nil, err
}
t.wgDevice = wgDev
t.Device = name
return t, nil
@@ -216,22 +312,44 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
// MUST match the flags used in newTun - includes IFF_VNET_HDR
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
// Set nonblocking mode - CRITICAL for proper netpoller integration
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Get MTU from main device
mtu := t.MaxMTU
if mtu == 0 {
mtu = DefaultMTU
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return file, nil
// Create wireguard Device from the file descriptor (just like the main device)
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
}
// Return a wrapper that uses the wireguard Device for all I/O
return &wgDeviceWrapper{dev: wgDev}, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -239,7 +357,68 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r
}
func (t *tun) Read(b []byte) (int, error) {
if t.wgDevice != nil {
// Use wireguard device which handles virtio headers internally
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := t.wgDevice.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
// Fallback: direct read from file (shouldn't happen in normal operation)
return t.ReadWriteCloser.Read(b)
}
// BatchRead reads multiple packets at once for improved performance
// bufs: slice of buffers to read into
// sizes: slice that will be filled with packet sizes
// Returns number of packets read
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Read(bufs, sizes, 0)
}
// Fallback: single packet read
n, err := t.ReadWriteCloser.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// BatchSize returns the optimal number of packets to read/write in a batch
func (t *tun) BatchSize() int {
if t.wgDevice != nil {
return t.wgDevice.BatchSize()
}
return 1
}
func (t *tun) Write(b []byte) (int, error) {
if t.wgDevice != nil {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
// Fallback: direct write (shouldn't happen in normal operation)
var nn int
maximum := len(b)
@@ -262,6 +441,22 @@ func (t *tun) Write(b []byte) (int, error) {
}
}
// WriteBatch writes multiple packets to the TUN device in a single syscall
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Write(bufs, offset)
}
// Fallback: write individually (shouldn't happen in normal operation)
for i, buf := range bufs {
_, err := t.Write(buf)
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
@@ -293,7 +488,6 @@ func (t *tun) addIPs(link netlink.Link) error {
//add all new addresses
for i := range newAddrs {
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
@@ -361,6 +555,11 @@ func (t *tun) Activate() error {
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation")
}
if err = t.addIPs(link); err != nil {
return err
}
@@ -638,6 +837,11 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return
}
if r.Dst == nil {
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
return
}
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
@@ -665,6 +869,10 @@ func (t *tun) Close() error {
close(t.routeChan)
}
if t.wgDevice != nil {
_ = t.wgDevice.Close()
}
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
}

View File

@@ -4,13 +4,12 @@
package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
@@ -20,11 +19,42 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
type ifreqDestroy struct {
Name [16]byte
pad [16]byte
const (
SIOCAIFADDR_IN6 = 0x8080696b
TUNSIFHEAD = 0x80047442
TUNSIFMODE = 0x80047458
)
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
}
type ifreq struct {
Name [unix.IFNAMSIZ]byte
data int
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type tun struct {
@@ -34,40 +64,18 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
io.ReadWriteCloser
f *os.File
fd int
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
@@ -77,17 +85,23 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
f: os.NewFile(uintptr(fd), ""),
fd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
@@ -105,40 +119,225 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil
}
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifr := ifreq{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
return err
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
rc, err := t.f.SyscallConn()
if err != nil {
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
}
var errno syscall.Errno
var n uintptr
err = rc.Read(func(fd uintptr) bool {
// first 4 bytes is protocol family, in network byte order
head := [4]byte{}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
if errno.Temporary() {
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
return false
}
return true
})
if err != nil {
if err == syscall.EBADF || err.Error() == "use of closed file" {
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
return 0, os.ErrClosed
}
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
}
if errno != 0 {
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, nil
} else if bytesRead < 4 {
return 0, nil
} else {
return bytesRead - 4, nil
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head [4]byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head[3] = syscall.AF_INET
} else if ipVer == 6 {
head[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var errno syscall.Errno
var n uintptr
err = rc.Write(func(fd uintptr) bool {
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
// According to NetBSD documentation for TUN, writes will only return errors in which
// this packet will never be delivered so just go on living life.
return true
})
if err != nil {
return 0, err
}
if errno != 0 {
return 0, errno
}
return int(n) - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
}
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
}
return nil
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime = addrLifetime{
Vltime: 0xffffffff,
Pltime: 0xffffffff,
}
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
mode := int32(unix.IFF_BROADCAST)
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
if err != nil {
return fmt.Errorf("failed to set tun device mode: %w", err)
}
v := 1
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
if err != nil {
return fmt.Errorf("failed to set tun device head: %w", err)
}
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -191,27 +390,52 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -224,10 +448,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -242,3 +464,109 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}

View File

@@ -4,23 +4,50 @@
package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
SIOCAIFADDR_IN6 = 0x8080691a
)
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime [2]uint32
}
type ifreq struct {
Name [unix.IFNAMSIZ]byte
data int
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
@@ -28,48 +55,46 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
io.ReadWriteCloser
f *os.File
fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
f: os.NewFile(uintptr(fd), ""),
fd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
@@ -87,6 +112,154 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil
}
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.f.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.f.Write(buf)
return n - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
}
err = addRoute(cidr, t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
}
return nil
}
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime[0] = 0xffffffff
req.Lifetime[1] = 0xffffffff
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
@@ -124,63 +297,65 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -192,10 +367,9 @@ func (t *tun) removeRoutes(routes []Route) error {
if !r.Install {
continue
}
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -204,52 +378,115 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
buf = buf[:len(from)+4]
return
}
if len(from) == 0 {
return 0, syscall.EIO
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
copy(buf[4:], from)
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
n, err := t.ReadWriteCloser.Write(buf)
return n - 4, err
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}

View File

@@ -132,6 +132,29 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}
func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *TestTun) BatchSize() int {
return 1
}

View File

@@ -6,7 +6,6 @@ package overlay
import (
"crypto"
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
@@ -234,10 +233,36 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
// BatchRead reads a single packet (batch size 1 for Windows)
func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for Windows)
func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for Windows (no batching)
func (t *winTun) BatchSize() int {
return 1
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.

View File

@@ -46,10 +46,36 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) {
return d, nil
}
// BatchRead reads a single packet (batch size 1 for UserDevice)
func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := d.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for UserDevice)
func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := d.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for UserDevice (no batching)
func (d *UserDevice) BatchSize() int {
return 1
}
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
return d.inboundReader, d.outboundWriter
}

View File

@@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
}
// Set up the parameters which include the peer's public key

90
pki.go
View File

@@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
currentState := p.cs.Load()
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
}
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
} else {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
nil,
)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil,
)
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new v1 cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
nil,
)
}
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil,
)
}
} else if currentState.v1Cert != nil {
//TODO: CERT-V2 we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
}
if newState.v2Cert != nil {
if currentState.v2Cert == nil {
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
}
//adding certs is fine, actually
} else {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
nil,
)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil,
)
}
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil,
)
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
nil,
)
}
}
} else if currentState.v2Cert != nil {
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
//newState.v1Cert is non-nil bc empty certstates aren't permitted
if newState.v1Cert == nil {
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
}
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
nil,
)
}
}
// Cipher cant be hot swapped so just leave it at what it was before
@@ -173,7 +180,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState)
//TODO: CERT-V2 newState needs a stringer that does json
if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else {
@@ -359,7 +365,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
}
//TODO: CERT-V2 make sure v2 has v1s address
if v1.Networks()[0] != v2.Networks()[0] {
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
}
cs.initiatingVersion = dv
}

View File

@@ -190,7 +190,7 @@ type RemoteList struct {
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
// A deduplicated set of addresses. Any accessor should lock beforehand.
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -201,8 +201,10 @@ type RemoteList struct {
// For learned addresses, this is the vpnIp that sent the packet
cache map[netip.Addr]*cache
hr *hostnamesResults
shouldAdd func(netip.Addr) bool
hr *hostnamesResults
// shouldAdd is a nillable function that decides if x should be added to addrs.
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
@@ -213,7 +215,7 @@ type RemoteList struct {
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0),
@@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
return c
}
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
r.Lock()
r.badRemotes = nil
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
copy(r.vpnAddrs, vpnAddrs)
r.Unlock()
}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
@@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
if !r.unlockedIsBad(addr) {
addrs = append(addrs, addr)
}

View File

@@ -6,6 +6,7 @@ import (
"log"
"net"
"net/http"
_ "net/http/pprof"
"runtime"
"strconv"
"time"

View File

@@ -13,12 +13,21 @@ type EncReader func(
payload []byte,
)
type EncBatchReader func(
addrs []netip.AddrPort,
payloads [][]byte,
count int,
)
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader)
ListenOutBatch(r EncBatchReader)
WriteTo(b []byte, addr netip.AddrPort) error
WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
ReloadConfig(c *config.C)
BatchSize() int
Close() error
}
@@ -33,12 +42,21 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) ListenOutBatch(_ EncBatchReader) {
return
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}
func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
return 0, nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
return
}
func (NoopConn) BatchSize() int {
return 1
}
func (NoopConn) Close() error {
return nil
}

View File

@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
}
}
// WriteMulti sends multiple packets - fallback implementation without sendmmsg
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
@@ -184,6 +195,34 @@ func (u *StdConn) ListenOut(r EncReader) {
}
}
// ListenOutBatch - fallback to single-packet reads for Darwin
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
buffer := make([]byte, MTU)
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
// Just read one packet at a time and call batch callback with count=1
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
}
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
payloads[0] = buffer[:n]
r(addrs, payloads, 1)
}
}
func (u *StdConn) BatchSize() int {
return 1
}
func (u *StdConn) Rebind() error {
var err error
if u.isV4 {

View File

@@ -85,3 +85,42 @@ func (u *GenericConn) ListenOut(r EncReader) {
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}
// ListenOutBatch - fallback to single-packet reads for generic platforms
func (u *GenericConn) ListenOutBatch(r EncBatchReader) {
buffer := make([]byte, MTU)
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
// Just read one packet at a time and call batch callback with count=1
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
payloads[0] = buffer[:n]
r(addrs, payloads, 1)
}
}
// WriteMulti sends multiple packets - fallback implementation
func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *GenericConn) BatchSize() int {
return 1
}
func (u *GenericConn) Rebind() error {
return nil
}

View File

@@ -22,6 +22,11 @@ type StdConn struct {
isV4 bool
l *logrus.Logger
batch int
// Pre-allocated buffers for batch writes (sized for IPv6, works for both)
writeMsgs []rawMessage
writeIovecs []iovec
writeNames [][]byte
}
func maybeIPV4(ip net.IP) (net.IP, bool) {
@@ -69,7 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err)
}
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}
// Pre-allocate write message structures for batching (sized for IPv6, works for both)
c.writeMsgs = make([]rawMessage, batch)
c.writeIovecs = make([]iovec, batch)
c.writeNames = make([][]byte, batch)
for i := range c.writeMsgs {
// Allocate for IPv6 size (larger than IPv4, works for both)
c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
// Point to the iovec in the slice
c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i]
c.writeMsgs[i].Hdr.Iovlen = 1
c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0]
// Namelen will be set appropriately in writeMulti4/writeMulti6
}
return c, err
}
func (u *StdConn) Rebind() error {
@@ -127,6 +151,8 @@ func (u *StdConn) ListenOut(r EncReader) {
read = u.ReadSingle
}
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
for {
n, err := read(msgs)
if err != nil {
@@ -134,6 +160,8 @@ func (u *StdConn) ListenOut(r EncReader) {
return
}
udpBatchHist.Update(int64(n))
for i := 0; i < n; i++ {
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 {
@@ -146,6 +174,46 @@ func (u *StdConn) ListenOut(r EncReader) {
}
}
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti
if u.batch == 1 {
read = u.ReadSingle
}
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
// Pre-allocate slices for batch callback
addrs := make([]netip.AddrPort, u.batch)
payloads := make([][]byte, u.batch)
for {
n, err := read(msgs)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpBatchHist.Update(int64(n))
// Prepare batch data
for i := 0; i < n; i++ {
if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payloads[i] = buffers[i][:msgs[i].Len]
}
// Call batch callback with all packets
r(addrs, payloads, n)
}
}
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
@@ -194,6 +262,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip)
}
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
if len(packets) != len(addrs) {
return 0, fmt.Errorf("packets and addrs length mismatch")
}
if len(packets) == 0 {
return 0, nil
}
if u.isV4 {
return u.writeMulti4(packets, addrs)
}
return u.writeMulti6(packets, addrs)
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
@@ -248,6 +329,123 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
}
}
func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) {
sent := 0
for sent < len(packets) {
// Determine batch size based on remaining packets and buffer capacity
batchSize := len(packets) - sent
if batchSize > len(u.writeMsgs) {
batchSize = len(u.writeMsgs)
}
// Use pre-allocated buffers
msgs := u.writeMsgs[:batchSize]
iovecs := u.writeIovecs[:batchSize]
names := u.writeNames[:batchSize]
// Setup message structures for this batch
for i := 0; i < batchSize; i++ {
pktIdx := sent + i
if !addrs[pktIdx].Addr().Is4() {
return sent + i, ErrInvalidIPv6RemoteForSocket
}
// Setup the packet buffer
iovecs[i].Base = &packets[pktIdx][0]
iovecs[i].Len = uint(len(packets[pktIdx]))
// Setup the destination address
rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0]))
rsa.Family = unix.AF_INET
rsa.Addr = addrs[pktIdx].Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
// Set the appropriate address length for IPv4
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4
}
// Send this batch
nsent, _, err := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(batchSize),
0,
0,
0,
)
if err != 0 {
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
}
sent += int(nsent)
if int(nsent) < batchSize {
// Couldn't send all packets in batch, return what we sent
return sent, nil
}
}
return sent, nil
}
func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) {
sent := 0
for sent < len(packets) {
// Determine batch size based on remaining packets and buffer capacity
batchSize := len(packets) - sent
if batchSize > len(u.writeMsgs) {
batchSize = len(u.writeMsgs)
}
// Use pre-allocated buffers
msgs := u.writeMsgs[:batchSize]
iovecs := u.writeIovecs[:batchSize]
names := u.writeNames[:batchSize]
// Setup message structures for this batch
for i := 0; i < batchSize; i++ {
pktIdx := sent + i
// Setup the packet buffer
iovecs[i].Base = &packets[pktIdx][0]
iovecs[i].Len = uint(len(packets[pktIdx]))
// Setup the destination address
rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0]))
rsa.Family = unix.AF_INET6
rsa.Addr = addrs[pktIdx].Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
// Set the appropriate address length for IPv6
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6
}
// Send this batch
nsent, _, err := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(batchSize),
0,
0,
0,
)
if err != 0 {
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
}
sent += int(nsent)
if int(nsent) < batchSize {
// Couldn't send all packets in batch, return what we sent
return sent, nil
}
}
return sent, nil
}
func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
@@ -305,6 +503,10 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
return nil
}
func (u *StdConn) BatchSize() int {
return u.batch
}
func (u *StdConn) Close() error {
return syscall.Close(u.sysFd)
}

View File

@@ -12,7 +12,7 @@ import (
type iovec struct {
Base *byte
Len uint32
Len uint
}
type msghdr struct {
@@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]

View File

@@ -12,7 +12,7 @@ import (
type iovec struct {
Base *byte
Len uint64
Len uint
}
type msghdr struct {
@@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]

View File

@@ -116,6 +116,31 @@ func (u *TesterConn) ListenOut(r EncReader) {
}
}
func (u *TesterConn) ListenOutBatch(r EncBatchReader) {
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
p, ok := <-u.RxPackets
if !ok {
return
}
addrs[0] = p.From
payloads[0] = p.Data
r(addrs, payloads, 1)
}
}
func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *TesterConn) ReloadConfig(*config.C) {}
func NewUDPStatsEmitter(_ []Conn) func() {
@@ -127,6 +152,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil
}
func (u *TesterConn) BatchSize() int {
return 1
}
func (u *TesterConn) Rebind() error {
return nil
}