Compare commits

..

18 Commits

Author SHA1 Message Date
Nate Brown
2ea8a72d5c dunno 2025-10-05 23:23:30 -05:00
Nate Brown
663232e1fc Testing the concept 2025-10-05 23:23:10 -05:00
Nate Brown
2f48529e8b Cleanup and note more work 2025-10-05 23:23:08 -05:00
Nate Brown
f3e1ad64cd Try the timeout 2025-10-05 23:22:29 -05:00
Nate Brown
1d8112a329 Revert "More playing" way too much garbage emitted
This reverts commit fa098c551a.
2025-10-05 23:22:29 -05:00
Nate Brown
31eea0cc94 More playing 2025-10-05 23:22:29 -05:00
Nate Brown
dbba4a4c77 Playing 2025-10-05 23:22:29 -05:00
Nate Brown
194fde45da non-blocking io for linux 2025-10-05 23:22:27 -05:00
Nate Brown
f46b83f2c4 Remove more os.Exit calls and give a more reliable wait for stop function 2025-10-05 23:20:43 -05:00
Nate Brown
fb7f0c3657 Use x/net/route to manage routes directly (#1488) 2025-10-03 10:59:53 -05:00
sl274
b1f53d8d25 Support IPv6 tunneling in FreeBSD (#1399)
Recent merge of cert-v2 support introduced the ability to tunnel IPv6. However, FreeBSD's IPv6 tunneling does not work for 2 reasons:
* The ifconfig commands did not work for IPv6 addresses
* The tunnel device was not configured for link-layer mode, so it only supported IPv4

This PR improves FreeBSD tunneling support in 3 ways:
* Use ioctl instead of exec'ing ifconfig to configure the interface, with additional logic to support IPv6
* Configure the tunnel in link-layer mode, allowing IPv6 traffic
* Use readv() and writev() to communicate with the tunnel device, to avoid the need to copy the packet buffer
2025-10-02 21:54:30 -05:00
Jack Doan
8824eeaea2 helper functions to more correctly marshal curve 25519 public keys (#1481) 2025-10-02 13:56:41 -05:00
dependabot[bot]
071589f7c7 Bump actions/setup-go from 5 to 6 (#1469)
* Bump actions/setup-go from 5 to 6

Bumps [actions/setup-go](https://github.com/actions/setup-go) from 5 to 6.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Hardcode the last one to go v1.25

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-10-02 00:05:12 -05:00
Jack Doan
f1e992f6dd don't require a detailsVpnAddr in a HostUpdateNotification (#1472)
* don't require a detailsVpnAddr in a HostUpdateNotification

* don't send our own addr on HostUpdateNotification for v2
2025-09-29 13:43:12 -05:00
Jack Doan
1ea5f776d7 update to go 1.25, use the cool new ECDSA key marshalling functions (#1483)
* update to go 1.25, use the cool new ECDSA key marshalling functions

* bonk the runners

* actually bump go.mod

* bump golangci-lint
2025-09-29 13:02:25 -05:00
Henry Graham
4cdeb284ef Set CKA_VALUE_LEN attribute in DeriveNoise (#1482) 2025-09-25 13:24:52 -05:00
Jack Doan
5cccd39465 update RemoteList.vpnAddrs when we complete a handshake (#1467) 2025-09-10 09:44:25 -05:00
Jack Doan
8196c22b5a store lighthouses as a slice (#1473)
* store lighthouses as a slice. If you have fewer than 16 lighthouses (and fewer than 16 vpnaddrs on a host, I guess), it's faster
2025-09-10 09:43:25 -05:00
36 changed files with 1000 additions and 325 deletions

View File

@@ -16,9 +16,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Install goimports

View File

@@ -12,9 +12,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build
@@ -35,9 +35,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build
@@ -68,9 +68,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Import certificates

View File

@@ -22,9 +22,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version-file: 'go.mod'
go-version: '1.25'
check-latest: true
- name: add hashicorp source

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: build

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build
@@ -34,7 +34,7 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
with:
version: v2.1
version: v2.5
- name: Test
run: make test
@@ -58,9 +58,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build
@@ -79,9 +79,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.22'
go-version: '1.25'
check-latest: true
- name: Build
@@ -100,9 +100,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- uses: actions/setup-go@v6
with:
go-version: '1.24'
go-version: '1.25'
check-latest: true
- name: Build nebula
@@ -117,7 +117,7 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
with:
version: v2.1
version: v2.5
- name: Test
run: make test

View File

@@ -5,6 +5,7 @@ import (
"github.com/sirupsen/logrus"
)
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
type Bits struct {
length uint64
current uint64
@@ -43,7 +44,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
}
// Not within the window
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
l.Error("rejected a packet (top) %d %d\n", b.current, i)
return false
}

View File

@@ -58,6 +58,9 @@ type Certificate interface {
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
PublicKey() []byte
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
MarshalPublicKeyPEM() []byte
// Curve identifies which curve was used for the PublicKey and Signature.
Curve() Curve

View File

@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey
}
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV1) Signature() []byte {
return c.signature
}
@@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -1,6 +1,7 @@
package cert
import (
"crypto/ed25519"
"fmt"
"net/netip"
"testing"
@@ -13,6 +14,7 @@ import (
)
func TestCertificateV1_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{

View File

@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
@@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -15,6 +15,7 @@ import (
)
func TestCertificateV2_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
publicKey: pubKey,
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.curve = Curve_P256
nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{
details: detailsV2{

View File

@@ -7,19 +7,26 @@ import (
"golang.org/x/crypto/ed25519"
)
const (
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
)
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
const ( //key-agreement-key banners
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
)
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
const ( //signing key banners
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
)
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -51,6 +58,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
} else {
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
}
}
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
@@ -62,6 +79,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
}
}
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
@@ -73,7 +103,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32
curve = Curve_CURVE25519
case P256PublicKeyBanner:
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
// Uncompressed
expectedLen = 65
curve = Curve_P256

View File

@@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalX25519PublicKey(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
oldPubP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)

View File

@@ -7,7 +7,6 @@ import (
"crypto/rand"
"crypto/sha256"
"fmt"
"math/big"
"net/netip"
"time"
)
@@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
if err != nil {
return nil, err
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1

View File

@@ -65,8 +65,16 @@ func main() {
}
if !*configTest {
ctrl.Start()
ctrl.ShutdownBlock()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -3,6 +3,9 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"github.com/sirupsen/logrus"
@@ -58,10 +61,22 @@ func main() {
os.Exit(1)
}
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest {
ctrl.Start()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l)
ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan netip.Addr, 10),
}
lighthouses := map[netip.Addr]struct{}{}
lighthouses := []netip.Addr{}
staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses)
@@ -446,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey
}
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte {
return d.signature
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/slackhq/nebula/noiseutil"
)
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
// 4092 should be sufficient for 5Gbps
const ReplayWindow = 1024
type ConnectionState struct {

View File

@@ -2,9 +2,11 @@ package nebula
import (
"context"
"errors"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
@@ -13,6 +15,16 @@ import (
"github.com/slackhq/nebula/overlay"
)
type RunState int
const (
Stopped RunState = 0 // The control has yet to be started
Started RunState = 1 // The control has been started
Stopping RunState = 2 // The control is stopping
)
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,6 +38,9 @@ type controlHostLister interface {
}
type Control struct {
stateLock sync.Mutex
state RunState
f *Interface
l *logrus.Logger
ctx context.Context
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Start actually runs nebula, this is a nonblocking call.
// The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Activate the interface
c.f.activate()
err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
@@ -72,15 +98,33 @@ func (c *Control) Start() {
}
// Start reading packets.
c.f.run()
c.state = Started
c.stateLock.Unlock()
return c.f.run(c.ctx)
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
}
func (c *Control) Context() context.Context {
return c.ctx
}
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
c.cancel()
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
}
c.l.Info("Goodbye")
c.state = Stopped
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

4
go.mod
View File

@@ -1,8 +1,6 @@
module github.com/slackhq/nebula
go 1.23.0
toolchain go1.24.1
go 1.25
require (
dario.cat/mergo v1.0.2

View File

@@ -459,7 +459,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.remotes.ResetBlockedRemotes()
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
return
}
@@ -667,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}
hostinfo.remotes.ResetBlockedRemotes()
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
f.metricHandshakes.Update(duration)
return false

View File

@@ -6,8 +6,8 @@ import (
"fmt"
"io"
"net/netip"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -18,6 +18,7 @@ import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp"
)
@@ -87,12 +88,19 @@ type Interface struct {
writers []udp.Conn
readers []io.ReadWriteCloser
wg sync.WaitGroup
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger
inPool sync.Pool
inbound chan *packet.Packet
outPool sync.Pool
outbound chan *[]byte
}
type EncWriter interface {
@@ -194,9 +202,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
},
//TODO: configurable size
inbound: make(chan *packet.Packet, 1028),
outbound: make(chan *[]byte, 1028),
l: c.l,
}
ifce.inPool = sync.Pool{New: func() any {
return packet.New()
}}
ifce.outPool = sync.Pool{New: func() any {
t := make([]byte, mtu)
return &t
}}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -209,7 +230,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
func (f *Interface) activate() error {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
@@ -230,33 +251,44 @@ func (f *Interface) activate() {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
f.l.Fatal(err)
return err
}
}
f.readers[i] = reader
}
if err := f.inside.Activate(); err != nil {
if err = f.inside.Activate(); err != nil {
f.inside.Close()
f.l.Fatal(err)
return err
}
return nil
}
func (f *Interface) run() {
// Launch n queues to read packets from udp
func (f *Interface) run(c context.Context) (func(), error) {
for i := 0; i < f.routines; i++ {
// Launch n queues to read packets from udp
f.wg.Add(1)
go f.listenOut(i)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.listenIn(f.readers[i], i)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.workerIn(i, c)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.workerOut(i, c)
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
}
return f.wg.Wait, nil
}
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if i > 0 {
li = f.writers[i]
@@ -264,41 +296,97 @@ func (f *Interface) listenOut(i int) {
li = f.outside
}
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
p := f.inPool.Get().(*packet.Packet)
//TODO: have the listener store this in the msgs array after a read instead of doing a copy
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
p.Payload = p.Payload[:mtu]
copy(p.Payload, payload)
p.Payload = p.Payload[:len(payload)]
p.Addr = fromUdpAddr
f.inbound <- p
//select {
//case f.inbound <- p:
//default:
// f.l.Error("Dropped packet from inbound channel")
//}
})
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close
}
f.l.Debugf("underlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
p := f.outPool.Get().(*[]byte)
*p = (*p)[:mtu]
n, err := reader.Read(*p)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
if !f.closed.Load() {
f.l.WithError(err).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
}
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
break
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
*p = (*p)[:n]
//TODO: nonblocking channel write
f.outbound <- p
//select {
//case f.outbound <- p:
//default:
// f.l.Error("Dropped packet from outbound channel")
//}
}
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) workerIn(i int, ctx context.Context) {
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket2 := &firewall.Packet{}
nb2 := make([]byte, 12, 12)
result2 := make([]byte, mtu)
h := &header.H{}
for {
select {
case p := <-f.inbound:
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
p.Payload = p.Payload[:mtu]
f.inPool.Put(p)
case <-ctx.Done():
f.wg.Done()
return
}
}
}
func (f *Interface) workerOut(i int, ctx context.Context) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket1 := &firewall.Packet{}
nb1 := make([]byte, 12, 12)
result1 := make([]byte, mtu)
for {
select {
case data := <-f.outbound:
f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
*data = (*data)[:mtu]
f.outPool.Put(data)
case <-ctx.Done():
f.wg.Done()
return
}
}
}
@@ -451,6 +539,7 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error {
f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers {
err := u.Close()
if err != nil {
@@ -458,6 +547,13 @@ func (f *Interface) Close() error {
}
}
// Release the tun device
return f.inside.Close()
// Release the tun readers
for _, u := range f.readers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing tun device")
}
}
return nil
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
)
var ErrHostNotKnown = errors.New("host not known")
@@ -58,7 +57,7 @@ type LightHouse struct {
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[[]netip.Addr]
interval atomic.Int64
updateCancel context.CancelFunc
@@ -109,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
lighthouses := make(map[netip.Addr]struct{})
lighthouses := make([]netip.Addr, 0)
h.lighthouses.Store(&lighthouses)
staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList)
@@ -145,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load()
}
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
func (lh *LightHouse) GetLighthouses() []netip.Addr {
return *lh.lighthouses.Load()
}
@@ -230,7 +229,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.updateCancel()
}
lh.StartUpdateWorkerNetlink()
lh.StartUpdateWorker()
}
}
@@ -308,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
if initial || c.HasChanged("lighthouse.hosts") {
lhMap := make(map[netip.Addr]struct{})
err := lh.parseLighthouses(c, lhMap)
lhList, err := lh.parseLighthouses(c)
if err != nil {
return err
}
lh.lighthouses.Store(&lhMap)
lh.lighthouses.Store(&lhList)
if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed")
@@ -348,36 +346,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil
}
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
out := make([]netip.Addr, len(lhs))
for i, host := range lhs {
addr, err := netip.ParseAddr(host)
if err != nil {
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
}
if !lh.myVpnNetworksTable.Contains(addr) {
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
}
lhMap[addr] = struct{}{}
out[i] = addr
}
if !lh.amLighthouse && len(lhMap) == 0 {
if !lh.amLighthouse && len(out) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
}
staticList := lh.GetStaticHostList()
for lhAddr, _ := range lhMap {
if _, ok := staticList[lhAddr]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
for i := range out {
if _, ok := staticList[out[i]]; !ok {
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
}
}
return nil
return out, nil
}
func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -488,7 +487,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
return lh.unlockedGetRemoteList(vpnAddrs)
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
}
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -571,7 +570,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetAddrs() {
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
continue
}
switch {
@@ -646,18 +645,17 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
}
}
//TODO lighthouse.remote_allow_ranges is almost certainly broken in a multiple-address-per-cert scenario
am := NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
am := NewRemoteList(allAddrs, lh.shouldAdd)
for _, addr := range allAddrs {
lh.addrMap[addr] = am
}
return am
}
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
}
if !allow {
@@ -712,15 +710,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
}
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
_, ok := lh.GetLighthouses()[vpnAddr]
return ok
l := lh.GetLighthouses()
for i := range l {
if l[i] == vpnAddr {
return true
}
}
return false
}
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
l := lh.GetLighthouses()
for _, a := range vpnAddr {
if _, ok := l[a]; ok {
return true
for i := range vpnAddrs {
for j := range l {
if l[j] == vpnAddrs[i] {
return true
}
}
}
return false
@@ -762,7 +767,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0
lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses {
for _, lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
@@ -846,65 +851,10 @@ func (lh *LightHouse) StartUpdateWorker() {
}()
}
func (lh *LightHouse) StartUpdateWorkerNetlink() {
if lh.amLighthouse {
return
}
addrChan := make(chan netlink.AddrUpdate)
doneChan := make(chan struct{})
err := netlink.AddrSubscribe(addrChan, doneChan)
if err != nil {
lh.l.WithError(err).Error("Failed to subscribe to address updates")
return
}
updateCtx, cancel := context.WithCancel(lh.ctx)
lh.updateCancel = cancel
go func() {
defer func() {
x := struct{}{}
doneChan <- x
}()
lh.SendUpdate()
for {
select {
case <-updateCtx.Done():
return
case a := <-addrChan:
addr, ok := netip.AddrFromSlice(a.LinkAddress.IP)
if !ok {
continue
}
ones, _ := a.LinkAddress.Mask.Size()
pfx := netip.PrefixFrom(addr, ones)
lh.l.WithFields(logrus.Fields{
"LinkAddress": pfx,
"LinkIndex": a.LinkIndex,
"Flags": a.Flags,
"Scope": a.Scope,
"NewAddr": a.NewAddr,
"ValidLft": a.ValidLft,
"PreferedLft": a.PreferedLft,
}).Warn("Received address update")
shouldSend := a.PreferedLft != 0 //example criteria
if shouldSend {
lh.SendUpdate()
}
}
}
}()
}
func (lh *LightHouse) SendUpdate() {
var v4 []*V4AddrPort
var v6 []*V6AddrPort
lh.l.Warn("sending lh update!")
for _, e := range lh.GetAdvertiseAddrs() {
if e.Addr().Is4() {
v4 = append(v4, netAddrToProtoV4AddrPort(e.Addr(), e.Port()))
@@ -935,7 +885,7 @@ func (lh *LightHouse) SendUpdate() {
updated := 0
lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses {
for _, lhVpnAddr := range lighthouses {
var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
@@ -993,7 +943,6 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4,
V6AddrPorts: v6,
RelayVpnAddrs: relays,
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
},
}
@@ -1116,7 +1065,16 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
}
return
}
@@ -1125,9 +1083,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 {
if !queryVpnAddr.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else {
@@ -1287,16 +1242,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return
}
detailsVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostUpdateNotification")
}
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
var detailsVpnAddr netip.Addr
var useVersion cert.Version
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2
} else {
detailsVpnAddr = netip.Addr{}
useVersion = cert.Version2
}
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
//Simple check that the host sent this not someone else
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
}
@@ -1310,24 +1273,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock()
n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck
if useVersion == cert.Version1 {
switch useVersion {
case cert.Version1:
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return
}
vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
} else {
case cert.Version2:
// do nothing, we want to send a blank message
default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return
}
@@ -1345,7 +1308,6 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return
}

18
main.go
View File

@@ -284,14 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
return &Control{
ifce,
l,
ctx,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorkerNetlink,
connManager.Start,
f: ifce,
l: l,
ctx: ctx,
cancel: cancel,
sshStart: sshStart,
statsStart: statsStart,
dnsStart: dnsStart,
lighthouseStart: lightHouse.StartUpdateWorker,
connectionManagerStart: connManager.Start,
}, nil
}

View File

@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
//f.l.Error("in packet ", h)
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel {
@@ -245,6 +245,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
return
}
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
@@ -470,7 +471,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
return false
}

View File

@@ -1,6 +1,7 @@
package overlay
import (
"net"
"net/netip"
"github.com/sirupsen/logrus"
@@ -70,3 +71,13 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
"sync/atomic"
@@ -554,13 +553,3 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -10,11 +10,9 @@ import (
"io"
"io/fs"
"net/netip"
"os"
"os/exec"
"strconv"
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/gaissmai/bart"
@@ -22,12 +20,18 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678
FIODGNAME = 0x80106678
TUNSIFMODE = 0x8004745e
TUNSIFHEAD = 0x80047460
OSIOCAIFADDR_IN6 = 0x8088691b
IN6_IFF_NODAD = 0x0020
)
type fiodgnameArg struct {
@@ -37,43 +41,159 @@ type fiodgnameArg struct {
}
type ifreqRename struct {
Name [16]byte
Name [unix.IFNAMSIZ]byte
Data uintptr
}
type ifreqDestroy struct {
Name [16]byte
Name [unix.IFNAMSIZ]byte
pad [16]byte
}
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
}
type ifreqMTU struct {
Name [unix.IFNAMSIZ]byte
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
}
io.ReadWriteCloser
func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if t.devFd >= 0 {
err := syscall.Close(t.devFd)
if err != nil {
return err
t.l.WithError(err).Error("Error closing device")
}
defer syscall.Close(s)
t.devFd = -1
ifreq := ifreqDestroy{Name: t.deviceBytes()}
c := make(chan struct{})
go func() {
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
}
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
}
}()
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
// wait up to 1 second so we start blocking at the ioctl
select {
case <-c:
case <-time.After(1 * time.Second):
}
}
return nil
@@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var fd int
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
}
if err != nil {
return nil, err
}
rawConn, err := file.SyscallConn()
if err != nil {
return nil, fmt.Errorf("SyscallConn: %v", err)
// Read the name of the interface
var name [16]byte
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
}
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil {
return nil, err
}
@@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
// If the name doesn't match the desired interface name, rename it now
if ifName != deviceName {
s, err := syscall.Socket(
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return nil, err
}
@@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
devFd: fd,
}
err = t.reload(c, true)
@@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
if cidr.Addr().Is4() {
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
if cidr.Addr().Is6() {
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
return t.addRoutes(false)
}
func (t *tun) setMTU() error {
// Set the MTU on the device
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -268,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -289,9 +484,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -306,3 +500,144 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}

12
packet/packet.go Normal file
View File

@@ -0,0 +1,12 @@
package packet
import "net/netip"
type Packet struct {
Payload []byte
Addr netip.AddrPort
}
func New() *Packet {
return &Packet{Payload: make([]byte, 9001)}
}

View File

@@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
}
// Set up the parameters which include the peer's public key

View File

@@ -190,7 +190,7 @@ type RemoteList struct {
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
// A deduplicated set of addresses. Any accessor should lock beforehand.
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -201,8 +201,10 @@ type RemoteList struct {
// For learned addresses, this is the vpnIp that sent the packet
cache map[netip.Addr]*cache
hr *hostnamesResults
shouldAdd func(netip.Addr) bool
hr *hostnamesResults
// shouldAdd is a nillable function that decides if x should be added to addrs.
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
@@ -213,7 +215,7 @@ type RemoteList struct {
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0),
@@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
return c
}
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
r.Lock()
r.badRemotes = nil
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
copy(r.vpnAddrs, vpnAddrs)
r.Unlock()
}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
@@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
if !r.unlockedIsBad(addr) {
addrs = append(addrs, addr)
}

View File

@@ -9,10 +9,13 @@ import (
"math"
"net"
"net/netip"
"os"
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup"
"gvisor.dev/gvisor/pkg/buffer"
@@ -43,8 +46,19 @@ type Service struct {
}
}
func New(control *nebula.Control) (*Service, error) {
control.Start()
func New(config *config.C) (*Service, error) {
logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return nil, err
}
wait, err := control.Start()
if err != nil {
return nil, err
}
ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx)
@@ -141,6 +155,12 @@ func New(control *nebula.Control) (*Service, error) {
}
})
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil
}

View File

@@ -16,7 +16,7 @@ type EncReader func(
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader)
ListenOut(r EncReader) error
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error

View File

@@ -71,15 +71,14 @@ type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader) {
func (u *GenericConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])

View File

@@ -9,6 +9,7 @@ import (
"net"
"net/netip"
"syscall"
"time"
"unsafe"
"github.com/rcrowley/go-metrics"
@@ -17,6 +18,8 @@ import (
"golang.org/x/sys/unix"
)
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
type StdConn struct {
sysFd int
isV4 bool
@@ -24,14 +27,6 @@ type StdConn struct {
batch int
}
func maybeIPV4(ip net.IP) (net.IP, bool) {
ip4 := ip.To4()
if ip4 != nil {
return ip4, true
}
return ip, false
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
af := unix.AF_INET6
if ip.Is4() {
@@ -55,6 +50,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
}
}
// Set a read timeout
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
}
var sa unix.Sockaddr
if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port}
@@ -118,7 +118,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
}
}
func (u *StdConn) ListenOut(r EncReader) {
func (u *StdConn) ListenOut(r EncReader) error {
var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
@@ -130,8 +130,7 @@ func (u *StdConn) ListenOut(r EncReader) {
for {
n, err := read(msgs)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
for i := 0; i < n; i++ {
@@ -159,6 +158,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
)
if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmsg", Err: err}
}
@@ -180,6 +182,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
)
if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmmsg", Err: err}
}
@@ -221,7 +226,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
if !ip.Addr().Is4() {
return ErrInvalidIPv6RemoteForSocket
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var rsa unix.RawSockaddrInet4

View File

@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil
}
func (u *RIOConn) ListenOut(r EncReader) {
func (u *RIOConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
for {