Compare commits

..

3 Commits

Author       SHA1         Message                                          Date
Jack Doan    4c745e8cfe   maintain existing punchy                         2025-06-09 12:35:28 -04:00
Jack Doan    87a4ec7d90   use queried hostmap info for deletion logging    2025-06-09 12:28:02 -04:00
Ryan Huber   47d4055e10   remove unused and stale tunnels. punch less.     2025-06-09 12:27:13 -04:00
67 changed files with 1187 additions and 4406 deletions

View File

@@ -17,5 +17,5 @@ contact_links:
about: 'The documentation is the best place to start if you are new to Nebula.'
- name: 💁 Support/Chat
url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
url: https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ
about: 'For faster support, join us on Slack for assistance!'

View File

@@ -16,9 +16,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Install goimports

View File

@@ -12,9 +12,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build
@@ -35,9 +35,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build
@@ -68,9 +68,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Import certificates

View File

@@ -22,9 +22,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version-file: 'go.mod'
check-latest: true
- name: add hashicorp source

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: build

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build
@@ -32,9 +32,9 @@ jobs:
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
uses: golangci/golangci-lint-action@v7
with:
version: v2.5
version: v2.0
- name: Test
run: make test
@@ -58,9 +58,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build
@@ -79,9 +79,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.22'
check-latest: true
- name: Build
@@ -100,9 +100,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build nebula
@@ -115,9 +115,9 @@ jobs:
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
uses: golangci/golangci-lint-action@v7
with:
version: v2.5
version: v2.0
- name: Test
run: make test

View File

@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
You can read more about Nebula [here](https://medium.com/p/884110a5579).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
## Supported Platforms

View File

@@ -84,11 +84,16 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
rawMap, ok := value.(map[string]any)
rawMap, ok := value.(map[any]any)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawCIDR, rawValue := range rawMap {
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
cidr, err := netip.ParsePrefix(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
@@ -124,7 +129,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
}
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[string]any)
rawMap, ok := raw.(map[any]any)
if !ok {
return nil, fmt.Errorf("invalid type: %T", raw)
}
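
The only difference in the hunk above is how the decoded config value is typed: map[string]any on one side versus map[any]any with a per-key string assertion on the other. Which shape shows up depends on the YAML decoder that produced the tree. The standalone sketch below is not part of this change; it assumes both gopkg.in/yaml.v2 and gopkg.in/yaml.v3 are on the module path and just prints the two dynamic types side by side.

package main

import (
	"fmt"

	yamlv2 "gopkg.in/yaml.v2"
	yamlv3 "gopkg.in/yaml.v3"
)

func main() {
	doc := []byte("10.0.0.0/24:\n  - mask: 192.168.1.0/24\n    port: 4242\n")

	// yaml.v2 decodes mappings into map[interface{}]interface{} ...
	var v2 any
	if err := yamlv2.Unmarshal(doc, &v2); err != nil {
		panic(err)
	}
	fmt.Printf("yaml.v2 decodes mappings as %T\n", v2)

	// ... while yaml.v3 produces map[string]interface{} for string keys.
	var v3 any
	if err := yamlv3.Unmarshal(doc, &v3); err != nil {
		panic(err)
	}
	fmt.Printf("yaml.v3 decodes mappings as %T\n", v3)
}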

View File

@@ -58,9 +58,6 @@ type Certificate interface {
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
PublicKey() []byte
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
MarshalPublicKeyPEM() []byte
// Curve identifies which curve was used for the PublicKey and Signature.
Curve() Curve
@@ -138,7 +135,8 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
case Version2:
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
default:
return nil, ErrUnknownVersion
//TODO: CERT-V2 make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
}
if err != nil {

View File

@@ -83,10 +83,6 @@ func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey
}
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV1) Signature() []byte {
return c.signature
}
@@ -114,10 +110,8 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -1,7 +1,6 @@
package cert
import (
"crypto/ed25519"
"fmt"
"net/netip"
"testing"
@@ -14,7 +13,6 @@ import (
)
func TestCertificateV1_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -62,58 +60,6 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{

View File

@@ -114,10 +114,6 @@ func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
@@ -153,10 +149,8 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -15,7 +15,6 @@ import (
)
func TestCertificateV2_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -76,58 +75,6 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
publicKey: pubKey,
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.curve = Curve_P256
nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{
details: detailsV2{

View File

@@ -20,7 +20,6 @@ var (
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")

View File

@@ -7,26 +7,19 @@ import (
"golang.org/x/crypto/ed25519"
)
const ( //cert banners
const (
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
)
const ( //key-agreement-key banners
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
)
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
const ( //signing key banners
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
)
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -58,16 +51,6 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
} else {
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
}
}
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
@@ -79,19 +62,6 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
}
}
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
@@ -103,7 +73,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32
curve = Curve_CURVE25519
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
case P256PublicKeyBanner:
// Uncompressed
expectedLen = 65
curve = Curve_P256
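
The banners in this file are only PEM block types; turning a raw key into one of these blocks is a single call to encoding/pem. A standalone sketch, using placeholder key bytes rather than a real key:

package main

import (
	"encoding/pem"
	"fmt"
)

func main() {
	pub := make([]byte, 32) // placeholder for a 32-byte X25519 public key

	// The banner string becomes the BEGIN/END line of the PEM block.
	out := pem.EncodeToMemory(&pem.Block{Type: "NEBULA X25519 PUBLIC KEY", Bytes: pub})
	fmt.Print(string(out))
}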

View File

@@ -177,7 +177,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -231,7 +230,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalX25519PublicKey(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -242,12 +240,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
oldPubP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -264,22 +256,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)

View File

@@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/sha256"
"fmt"
"math/big"
"net/netip"
"time"
)
@@ -54,10 +55,15 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
if err != nil {
return nil, err
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
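
Both sides of the hunk above turn Nebula's raw 32-byte P-256 scalar into an *ecdsa.PrivateKey: one side calls a parser, the other reconstructs D and the public point by hand. The standalone sketch below (a throwaway key, not repository code) walks the manual path and then verifies a signature against the uncompressed public key bytes, mirroring the CheckSignature hunks earlier in this compare.

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
	"math/big"
)

func main() {
	// Generate a throwaway P-256 key, then keep only the 32-byte scalar,
	// which is the raw private key shape the hunk above starts from.
	k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	raw := k.D.FillBytes(make([]byte, 32))

	// Manual reconstruction: D from the scalar, the public point via ScalarBaseMult.
	pk := &ecdsa.PrivateKey{PublicKey: ecdsa.PublicKey{Curve: elliptic.P256()}}
	pk.D = new(big.Int).SetBytes(raw)
	pk.X, pk.Y = pk.Curve.ScalarBaseMult(raw)

	// ECDSA signs a hash, so hash first, then sign.
	digest := sha256.Sum256([]byte("example payload"))
	sig, err := ecdsa.SignASN1(rand.Reader, pk, digest[:])
	if err != nil {
		panic(err)
	}

	// Verify against the 65-byte uncompressed public key, as CheckSignature does.
	uncompressed := elliptic.Marshal(elliptic.P256(), pk.X, pk.Y)
	x, y := elliptic.Unmarshal(elliptic.P256(), uncompressed)
	pub := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
	fmt.Println("signature valid:", ecdsa.VerifyASN1(pub, digest[:], sig))
}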

View File

@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
}
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
nc := &cert.TBSCertificate{
Version: v,
Curve: c.Curve(),
Name: c.Name(),
Networks: c.Networks(),
UnsafeNetworks: c.UnsafeNetworks(),
Groups: c.Groups(),
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
PublicKey: c.PublicKey(),
IsCA: false,
}
c, err := nc.Sign(ca, ca.Curve(), key)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pem
}
func X25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {

View File

@@ -4,10 +4,8 @@ import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
@@ -29,127 +27,311 @@ const (
sendTestPacket trafficDecision = 6
)
// LastCommunication tracks when we last communicated with a host
type LastCommunication struct {
timestamp time.Time
vpnIp netip.Addr // To help with logging
}
type connectionManager struct {
in map[uint32]struct{}
inLock *sync.RWMutex
out map[uint32]struct{}
outLock *sync.RWMutex
// relayUsed holds which relay localIndexs are in use
relayUsed map[uint32]struct{}
relayUsedLock *sync.RWMutex
// Track last communication with hosts
lastCommMap map[uint32]time.Time
lastCommLock *sync.RWMutex
inactivityTimer *LockingTimerWheel[uint32]
inactivityTimeout time.Duration
hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32]
intf *Interface
pendingDeletion map[uint32]struct{}
punchy *Punchy
// Configuration settings
checkInterval time.Duration
pendingDeletionInterval time.Duration
inactivityTimeout atomic.Int64
dropInactive atomic.Bool
metricsTxPunchy metrics.Counter
l *logrus.Logger
}
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
cm := &connectionManager{
hostMap: hm,
l: l,
punchy: p,
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
var max time.Duration
if checkInterval < pendingDeletionInterval {
max = pendingDeletionInterval
} else {
max = checkInterval
}
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
inLock: &sync.RWMutex{},
out: make(map[uint32]struct{}),
outLock: &sync.RWMutex{},
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
lastCommMap: make(map[uint32]time.Time),
lastCommLock: &sync.RWMutex{},
inactivityTimeout: 1 * time.Minute, // Default inactivity timeout: 10 minutes
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
intf: intf,
pendingDeletion: make(map[uint32]struct{}),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
punchy: punchy,
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
l: l,
}
cm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
cm.reload(c, false)
})
// Initialize the inactivity timer wheel - make wheel duration slightly longer than the timeout
nc.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, nc.inactivityTimeout+time.Minute)
return cm
nc.Start(ctx)
return nc
}
func (cm *connectionManager) reload(c *config.C, initial bool) {
if initial {
cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
// pretty close to their configured duration.
// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
}
if initial || c.HasChanged("tunnels.inactivity_timeout") {
old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial {
cm.l.WithField("oldDuration", old).
WithField("newDuration", cm.getInactivityTimeout()).
Info("Inactivity timeout has changed")
}
}
if initial || c.HasChanged("tunnels.drop_inactive") {
old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial {
cm.l.WithField("oldBool", old).
WithField("newBool", cm.dropInactive.Load()).
Info("Drop inactive setting has changed")
}
}
}
func (cm *connectionManager) getInactivityTimeout() time.Duration {
return (time.Duration)(cm.inactivityTimeout.Load())
}
func (cm *connectionManager) In(h *HostInfo) {
h.in.Store(true)
}
func (cm *connectionManager) Out(h *HostInfo) {
h.out.Store(true)
}
func (cm *connectionManager) RelayUsed(localIndex uint32) {
cm.relayUsedLock.RLock()
// If this already exists, return
if _, ok := cm.relayUsed[localIndex]; ok {
cm.relayUsedLock.RUnlock()
func (n *connectionManager) updateLastCommunication(localIndex uint32) {
// Get host info to record VPN IP for better logging
hostInfo := n.hostMap.QueryIndex(localIndex)
if hostInfo == nil {
return
}
cm.relayUsedLock.RUnlock()
cm.relayUsedLock.Lock()
cm.relayUsed[localIndex] = struct{}{}
cm.relayUsedLock.Unlock()
now := time.Now()
n.lastCommLock.Lock()
n.lastCommMap[localIndex] = now
n.lastCommLock.Unlock()
// Reset the inactivity timer for this host
n.inactivityTimer.m.Lock()
n.inactivityTimer.t.Add(localIndex, n.inactivityTimeout)
n.inactivityTimer.m.Unlock()
}
func (n *connectionManager) In(localIndex uint32) {
n.inLock.RLock()
// If this already exists, return
if _, ok := n.in[localIndex]; ok {
n.inLock.RUnlock()
return
}
n.inLock.RUnlock()
n.inLock.Lock()
n.in[localIndex] = struct{}{}
n.inLock.Unlock()
// Update last communication time
n.updateLastCommunication(localIndex)
}
func (n *connectionManager) Out(localIndex uint32) {
n.outLock.RLock()
// If this already exists, return
if _, ok := n.out[localIndex]; ok {
n.outLock.RUnlock()
return
}
n.outLock.RUnlock()
n.outLock.Lock()
n.out[localIndex] = struct{}{}
n.outLock.Unlock()
// Update last communication time
n.updateLastCommunication(localIndex)
}
func (n *connectionManager) RelayUsed(localIndex uint32) {
n.relayUsedLock.RLock()
// If this already exists, return
if _, ok := n.relayUsed[localIndex]; ok {
n.relayUsedLock.RUnlock()
return
}
n.relayUsedLock.RUnlock()
n.relayUsedLock.Lock()
n.relayUsed[localIndex] = struct{}{}
n.relayUsedLock.Unlock()
}
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
// resets the state for this local index
func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
in := h.in.Swap(false)
out := h.out.Swap(false)
if in || out {
h.lastUsed = now
}
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
n.inLock.Lock()
n.outLock.Lock()
_, in := n.in[localIndex]
_, out := n.out[localIndex]
delete(n.in, localIndex)
delete(n.out, localIndex)
n.inLock.Unlock()
n.outLock.Unlock()
return in, out
}
// AddTrafficWatch must be called for every new HostInfo.
// We will continue to monitor the HostInfo until the tunnel is dropped.
func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
if h.out.Swap(true) == false {
cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
n.outLock.Lock()
if _, ok := n.out[localIndex]; ok {
n.outLock.Unlock()
return
}
n.out[localIndex] = struct{}{}
n.trafficTimer.Add(localIndex, n.checkInterval)
n.outLock.Unlock()
}
// checkInactiveTunnels checks for tunnels that have been inactive for too long and drops them
func (n *connectionManager) checkInactiveTunnels() {
now := time.Now()
// First, advance the timer wheel to the current time
n.inactivityTimer.m.Lock()
n.inactivityTimer.t.Advance(now)
n.inactivityTimer.m.Unlock()
// Check for expired timers (inactive connections)
for {
// Get the next expired tunnel
n.inactivityTimer.m.Lock()
localIndex, ok := n.inactivityTimer.t.Purge()
n.inactivityTimer.m.Unlock()
if !ok {
// No more expired timers
break
}
n.lastCommLock.RLock()
lastComm, exists := n.lastCommMap[localIndex]
n.lastCommLock.RUnlock()
if !exists {
// No last communication record, odd but skip
continue
}
// Calculate inactivity duration
inactiveDuration := now.Sub(lastComm)
// Check if we've exceeded the inactivity timeout
if inactiveDuration >= n.inactivityTimeout {
// Get the host info (if it still exists)
hostInfo := n.hostMap.QueryIndex(localIndex)
if hostInfo == nil {
// Host info is gone, remove from our tracking map
n.lastCommLock.Lock()
delete(n.lastCommMap, localIndex)
n.lastCommLock.Unlock()
continue
}
// Log the inactivity and drop the tunnel
n.l.WithField("vpnIp", hostInfo.vpnAddrs[0]).
WithField("localIndex", localIndex).
WithField("inactiveDuration", inactiveDuration).
WithField("timeout", n.inactivityTimeout).
Info("Dropping tunnel due to inactivity")
// Close the tunnel using the existing mechanism
n.intf.closeTunnel(hostInfo)
// Clean up our tracking map
n.lastCommLock.Lock()
delete(n.lastCommMap, localIndex)
n.lastCommLock.Unlock()
} else {
// Re-add to the timer wheel with the remaining time
remainingTime := n.inactivityTimeout - inactiveDuration
n.inactivityTimer.m.Lock()
n.inactivityTimer.t.Add(localIndex, remainingTime)
n.inactivityTimer.m.Unlock()
}
}
}
func (cm *connectionManager) Start(ctx context.Context) {
clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
// CleanupDeletedHostInfos removes entries from our lastCommMap for hosts that no longer exist
func (n *connectionManager) CleanupDeletedHostInfos() {
n.lastCommLock.Lock()
defer n.lastCommLock.Unlock()
// Find indexes to delete
var toDelete []uint32
for localIndex := range n.lastCommMap {
if n.hostMap.QueryIndex(localIndex) == nil {
toDelete = append(toDelete, localIndex)
}
}
// Delete them
for _, localIndex := range toDelete {
delete(n.lastCommMap, localIndex)
}
if len(toDelete) > 0 && n.l.Level >= logrus.DebugLevel {
n.l.WithField("count", len(toDelete)).Debug("Cleaned up deleted host entries from lastCommMap")
}
}
// ReloadConfig updates the connection manager configuration
func (n *connectionManager) ReloadConfig(c *config.C) {
// Get the inactivity timeout from config
inactivityTimeout := c.GetDuration("timers.inactivity_timeout", 10*time.Minute)
// Only update if different
if inactivityTimeout != n.inactivityTimeout {
n.l.WithField("old", n.inactivityTimeout).
WithField("new", inactivityTimeout).
Info("Updating inactivity timeout")
n.inactivityTimeout = inactivityTimeout
// Recreate the inactivity timer wheel with the new timeout
n.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, n.inactivityTimeout+time.Minute)
// Re-add all existing hosts to the new timer wheel
n.lastCommLock.RLock()
for localIndex, lastComm := range n.lastCommMap {
// Calculate remaining time based on last communication
now := time.Now()
elapsed := now.Sub(lastComm)
// If the elapsed time exceeds the new timeout, this will be caught
// in the next inactivity check. Otherwise, add with remaining time.
if elapsed < n.inactivityTimeout {
remainingTime := n.inactivityTimeout - elapsed
n.inactivityTimer.m.Lock()
n.inactivityTimer.t.Add(localIndex, remainingTime)
n.inactivityTimer.m.Unlock()
}
}
n.lastCommLock.RUnlock()
}
}
func (n *connectionManager) Start(ctx context.Context) {
go n.Run(ctx)
}
func (n *connectionManager) Run(ctx context.Context) {
//TODO: this tick should be based on the min wheel tick? Check firewall
clockSource := time.NewTicker(500 * time.Millisecond)
defer clockSource.Stop()
// Create ticker for inactivity checks (every minute)
inactivityTicker := time.NewTicker(time.Minute)
defer inactivityTicker.Stop()
// Create ticker for cleanup (every 5 minutes)
cleanupTicker := time.NewTicker(5 * time.Minute)
defer cleanupTicker.Stop()
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
@@ -160,61 +342,69 @@ func (cm *connectionManager) Start(ctx context.Context) {
return
case now := <-clockSource.C:
cm.trafficTimer.Advance(now)
n.trafficTimer.Advance(now)
for {
localIndex, has := cm.trafficTimer.Purge()
localIndex, has := n.trafficTimer.Purge()
if !has {
break
}
cm.doTrafficCheck(localIndex, p, nb, out, now)
n.doTrafficCheck(localIndex, p, nb, out, now)
}
case <-inactivityTicker.C:
// Check for inactive tunnels
n.checkInactiveTunnels()
case <-cleanupTicker.C:
// Periodically clean up deleted hosts
n.CleanupDeletedHostInfos()
}
}
}
func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
switch decision {
case deleteTunnel:
if cm.hostMap.DeleteHostInfo(hostinfo) {
if n.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
}
case closeTunnel:
cm.intf.sendCloseTunnel(hostinfo)
cm.intf.closeTunnel(hostinfo)
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo)
case swapPrimary:
cm.swapPrimary(hostinfo, primary)
n.swapPrimary(hostinfo, primary)
case migrateRelays:
cm.migrateRelayUsed(hostinfo, primary)
n.migrateRelayUsed(hostinfo, primary)
case tryRehandshake:
cm.tryRehandshake(hostinfo)
n.tryRehandshake(hostinfo)
case sendTestPacket:
cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
}
cm.resetRelayTrafficCheck(hostinfo)
n.resetRelayTrafficCheck(hostinfo)
}
func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
if hostinfo != nil {
cm.relayUsedLock.Lock()
defer cm.relayUsedLock.Unlock()
n.relayUsedLock.Lock()
defer n.relayUsedLock.Unlock()
// No need to migrate any relays, delete usage info now.
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(cm.relayUsed, idx)
delete(n.relayUsed, idx)
}
}
}
func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor {
@@ -224,51 +414,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
var relayFrom netip.Addr
var relayTo netip.Addr
switch {
case ok:
switch existing.State {
case Established, PeerRequested, Disestablished:
case ok && existing.State == Established:
// This relay already exists in newhostinfo, then do nothing.
continue
case Requested:
case ok && existing.State == Requested:
// The relay exists in a Requested state; re-send the request
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = cm.intf.myVpnAddrs[0]
relayFrom = n.intf.myVpnAddrs[0]
relayTo = existing.PeerAddr
case ForwardingType:
relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
}
}
case !ok:
cm.relayUsedLock.RLock()
if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
n.relayUsedLock.RLock()
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
// The relay hasn't been used; don't migrate it.
cm.relayUsedLock.RUnlock()
n.relayUsedLock.RUnlock()
continue
}
cm.relayUsedLock.RUnlock()
n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request.
var err error
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil {
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue
}
switch r.Type {
case TerminalType:
relayFrom = cm.intf.myVpnAddrs[0]
relayFrom = n.intf.myVpnAddrs[0]
relayTo = r.PeerAddr
case ForwardingType:
relayFrom = r.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
}
}
@@ -281,12 +466,12 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
switch newhostinfo.GetCert().Certificate.Version() {
case cert.Version1:
if !relayFrom.Is4() {
cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !relayTo.Is4() {
cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
@@ -298,16 +483,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayToAddr = netAddrToProtoAddr(relayTo)
default:
newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
continue
}
msg, err := req.Marshal()
if err != nil {
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else {
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
cm.l.WithFields(logrus.Fields{
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex,
@@ -318,44 +503,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
}
}
func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
cm.hostMap.RLock()
defer cm.hostMap.RUnlock()
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
n.hostMap.RLock()
defer n.hostMap.RUnlock()
hostinfo := cm.hostMap.Indexes[localIndex]
hostinfo := n.hostMap.Indexes[localIndex]
if hostinfo == nil {
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
delete(n.pendingDeletion, localIndex)
return doNothing, nil, nil
}
if cm.isInvalidCertificate(now, hostinfo) {
if n.isInvalidCertificate(now, hostinfo) {
delete(n.pendingDeletion, hostinfo.localIndexId)
return closeTunnel, hostinfo, nil
}
primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
}
// Check for traffic on this hostinfo
inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
decision := doNothing
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
hostinfo.pendingDeletion.Store(false)
delete(n.pendingDeletion, hostinfo.localIndexId)
if mainHostInfo {
decision = tryRehandshake
} else {
if cm.shouldSwapPrimary(hostinfo) {
if n.shouldSwapPrimary(hostinfo, primary) {
decision = swapPrimary
} else {
// migrate the relays to the primary, if in use.
@@ -363,55 +550,46 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
}
}
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
if !outTraffic {
// Send a punch packet to keep the NAT state alive
cm.sendPunch(hostinfo)
n.sendPunch(hostinfo)
}
return decision, hostinfo, primary
}
if hostinfo.pendingDeletion.Load() {
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
// We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(cm.l).
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")
delete(n.pendingDeletion, hostinfo.localIndexId)
return deleteTunnel, hostinfo, nil
}
decision := doNothing
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic {
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
if isInactive {
// Tunnel is inactive, tear it down
hostinfo.logger(cm.l).
WithField("inactiveDuration", inactiveFor).
WithField("primary", mainHostInfo).
Info("Dropping tunnel due to inactivity")
return closeTunnel, hostinfo, primary
}
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so.
cm.sendPunch(hostinfo)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
n.sendPunch(hostinfo)
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
return doNothing, nil, nil
}
if cm.punchy.GetTargetEverything() {
if n.punchy.GetTargetEverything() {
// This is similar to the old punchy behavior with a slight optimization.
// We aren't receiving traffic but we are sending it, punch on all known
// ips in case we need to re-prime NAT state
cm.sendPunch(hostinfo)
n.sendPunch(hostinfo)
}
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
}
@@ -420,33 +598,17 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
decision = sendTestPacket
} else {
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
}
}
hostinfo.pendingDeletion.Store(true)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
return decision, hostinfo, nil
}
func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
if cm.dropInactive.Load() == false {
// We aren't configured to drop inactive tunnels
return 0, false
}
inactiveDuration := now.Sub(hostinfo.lastUsed)
if inactiveDuration < cm.getInactivityTimeout() {
// It's not considered inactive
return inactiveDuration, false
}
// The tunnel is inactive
return inactiveDuration, true
}
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
@@ -454,127 +616,83 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
// vpn addr is static across all tunnels for this host pair so lets
// use that to determine if we should consider swapping.
if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Their primary vpn addr is less than mine. Do not swap.
return false
}
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
if crt == nil {
//my cert was reloaded away. We should definitely swap from this tunnel
return true
}
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
}
func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
cm.hostMap.Lock()
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
cm.hostMap.unlockedMakePrimary(current)
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
n.hostMap.unlockedMakePrimary(current)
}
cm.hostMap.Unlock()
n.hostMap.Unlock()
}
// isInvalidCertificate decides if we should destroy a tunnel.
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
// check and return true.
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
remoteCert := hostinfo.GetCert()
if remoteCert == nil {
return false //don't tear down tunnels for handshakes in progress
}
caPool := cm.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil {
return false //cert is still valid! yay!
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
// Block listed certificates should always be disconnected
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is blocked, tearing down the tunnel")
return true
} else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
} else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
return false
}
caPool := n.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil {
return false
}
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if !cm.punchy.GetPunch() {
if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
// Block listed certificates should always be disconnected
return false
}
hostinfo.logger(n.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
}
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
if !n.punchy.GetPunch() {
// Punching is disabled
return
}
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
// would lose the ability to notify us and punchy.respond would become unreliable.
return
}
if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
if n.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
n.metricsTxPunchy.Inc(1)
_ = n.intf.outside.WriteTo([]byte{1}, addr)
})
} else if hostinfo.remote.IsValid() {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
n.metricsTxPunchy.Inc(1)
_ = n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return
}
peerCrt := hostinfo.ConnectionState.peerCert
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version lower than peer, attempting to correct").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
return
}
}
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "current cert version < pki.initiatingVersion").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
}
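
Both connection-manager variants above reduce to the same decision: drop a tunnel once the time since its last traffic exceeds a configured timeout, gated by the tunnels.drop_inactive switch (one variant tracks last traffic in lastCommMap keyed by local index, the other on hostinfo.lastUsed). A standalone sketch of just that decision, using a made-up tunnel type rather than HostInfo:

package main

import (
	"fmt"
	"time"
)

// tunnel stands in for a HostInfo; only the last-traffic timestamp matters here.
type tunnel struct {
	lastUsed time.Time
}

// isInactive reports how long the tunnel has been idle and whether that exceeds the
// timeout. dropInactive mirrors tunnels.drop_inactive: when false, nothing is dropped.
func isInactive(t *tunnel, now time.Time, timeout time.Duration, dropInactive bool) (time.Duration, bool) {
	if !dropInactive {
		return 0, false
	}
	idle := now.Sub(t.lastUsed)
	return idle, idle >= timeout
}

func main() {
	t := &tunnel{lastUsed: time.Now().Add(-11 * time.Minute)}
	idle, drop := isInactive(t, time.Now(), 10*time.Minute, true)
	fmt.Printf("idle for %s, drop=%v\n", idle.Round(time.Second), drop)
}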

View File

@@ -1,6 +1,7 @@
package nebula
import (
"context"
"crypto/ed25519"
"crypto/rand"
"net/netip"
@@ -22,7 +23,7 @@ func newTestLighthouse() *LightHouse {
addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan netip.Addr, 10),
}
lighthouses := []netip.Addr{}
lighthouses := map[netip.Addr]struct{}{}
staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses)
@@ -63,10 +64,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
@@ -84,33 +85,32 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp
nc.Out(hostinfo)
nc.In(hostinfo)
assert.False(t, hostinfo.pendingDeletion.Load())
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
assert.Contains(t, nc.out, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
// Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo)
assert.True(t, hostinfo.out.Load())
nc.Out(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.True(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
}
@@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
@@ -167,129 +167,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp
nc.Out(hostinfo)
nc.In(hostinfo)
assert.True(t, hostinfo.in.Load())
assert.True(t, hostinfo.out.Load())
assert.False(t, hostinfo.pendingDeletion.Load())
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
// Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo)
nc.Out(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.True(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// We saw traffic, should no longer be pending deletion
nc.In(hostinfo)
nc.In(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
l := test.NewLogger()
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnAddrs: vpnAddrs,
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// Do a traffic check tick, in and out should be cleared but should not be pending deletion
nc.Out(hostinfo)
nc.In(hostinfo)
assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
now := time.Now()
decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
assert.Equal(t, tryRehandshake, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, should still not be pending deletion
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Finally advance beyond the inactivity timeout
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
assert.Equal(t, closeTunnel, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
@@ -360,10 +264,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
ifce.connectionManager = nc
hostinfo := &HostInfo{
@@ -446,10 +350,6 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey
}
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte {
return d.signature
}

View File

@@ -34,7 +34,6 @@ type Control struct {
statsStart func()
dnsStart func()
lighthouseStart func()
connectionManagerStart func(context.Context)
}
type ControlHostInfo struct {
@@ -64,9 +63,6 @@ func (c *Control) Start() {
if c.dnsStart != nil {
go c.dnsStart()
}
if c.connectionManagerStart != nil {
go c.connectionManagerStart(c.ctx)
}
if c.lighthouseStart != nil {
c.lighthouseStart()
}

View File

@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
localIndexId: 201,
vpnAddrs: []netip.Addr{vpnIp},
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
localIndexId: 201,
vpnAddrs: []netip.Addr{vpnIp2},
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},

View File

@@ -506,7 +506,7 @@ func TestReestablishRelays(t *testing.T) {
curIndexes := len(myControl.GetHostmap().Indexes)
for curIndexes >= start {
curIndexes = len(myControl.GetHostmap().Indexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -1052,9 +1052,6 @@ func TestRehandshakingLoser(t *testing.T) {
t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew their certificate and spin until mine sees it")

View File

@@ -129,109 +129,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
return control, vpnNetworks, udpAddr, c
}
// newServer creates a nebula instance with fewer assumptions
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnNetworks := certs[len(certs)-1].Networks()
var udpAddr netip.AddrPort
if vpnNetworks[0].Addr().Is4() {
budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnNetworks[0].Addr().As16()
// beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
caStr := ""
for _, ca := range caCrt {
x, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
caStr += string(x)
}
certStr := ""
for _, c := range certs {
x, err := c.MarshalPEM()
if err != nil {
panic(err)
}
certStr += string(x)
}
mc := m{
"pki": m{
"ca": caStr,
"cert": certStr,
"key": string(key),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{
"host": udpAddr.Addr().String(),
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
},
"timers": m{
"pending_deletion_interval": 2,
"connection_alive_interval": 2,
},
}
if overrides != nil {
final := m{}
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
c := config.NewC(l)
cStr := string(cb)
c.LoadString(cStr)
control, err := nebula.Main(c, false, "e2e-test", l, nil)
if err != nil {
panic(err)
}
return control, vpnNetworks, udpAddr, c
}
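For orientation, the certificate tests later in this compare invoke this helper roughly as follows (a usage sketch assembled from TestCertUpgrade below, not a new API):

// Usage sketch: two nodes that both trust the v1 and v2 test CAs.
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Exchange underlay addresses and start both instances before asserting a tunnel.
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
myControl.Start()
theirControl.Start()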
type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb {

View File

@@ -700,7 +700,6 @@ func (r *R) FlushAll() {
r.Unlock()
panic("Can't FlushAll for host: " + p.To.String())
}
receiver.InjectUDPPacket(p)
r.Unlock()
}
}

View File

@@ -1,320 +0,0 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestDropInactiveTunnels(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("Go inactive and wait for the tunnels to get dropped")
waitStart := time.Now()
for {
myIndexes := len(myControl.GetHostmap().Indexes)
theirIndexes := len(theirControl.GetHostmap().Indexes)
if myIndexes == 0 && theirIndexes == 0 {
break
}
since := time.Since(waitStart)
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
if since > time.Second*30 {
t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
}
time.Sleep(1 * time.Second)
r.FlushAll()
}
r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
myControl.Stop()
theirControl.Stop()
}
func TestCertUpgrade(t *testing.T) {
// The goal of this test is to ensure a host can reload from a v1 certificate to a v2-only certificate
// while the existing tunnel stays up and the peer eventually sees the new cert version
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCert2Pem),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": myC.Settings["firewall"],
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v2-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if c == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
r.Logf("version %d", version)
if version == cert.Version2 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*10 {
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertDowngrade(t *testing.T) {
// The goal of this test is to ensure a host that started with a v2 certificate can reload back to a v1-only certificate
// and the peer eventually sees the older cert version over the existing tunnel
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCertPem),
"key": string(myPrivKey),
},
"firewall": myC.Settings["firewall"],
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v1-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == cert.Version1 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertMismatchCorrection(t *testing.T) {
// The goal of this test is to ensure two hosts whose certificate versions initially differ converge on a matching
// version without any config reload
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == theirVersion {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("wtf")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -338,18 +338,6 @@ logging:
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# Tunnel manager settings
#tunnels:
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactivity_timeout period has
# elapsed.
# In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
# This setting is reloadable
#drop_inactive: false
# inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
# inactive and eligible to be dropped.
# This setting is reloadable
#inactivity_timeout: 10m
# Nebula security group configuration
firewall:
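The removed tunnels block above documents drop_inactive and inactivity_timeout; the tests elsewhere in this compare enable the same behavior programmatically. A minimal sketch combining the two, using only helper names that appear in the test code shown here (config.NewC, newHostMap, NewPunchyFromConfig, newConnectionManagerFromConfig):

// Sketch: enabling inactive-tunnel dropping the way the tests in this compare do.
l := test.NewLogger()
hostMap := newHostMap(l)
conf := config.NewC(l)
conf.Settings["tunnels"] = map[string]any{
	"drop_inactive":      true, // close tunnels that go idle
	"inactivity_timeout": "5s", // assumed to be parsed the same way as in the e2e config above
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
// nc.dropInactive.Load() now reports true, matching the unit test assertions.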

View File

@@ -68,9 +68,6 @@ func TestFirewall_AddRule(t *testing.T) {
ti, err := netip.ParsePrefix("1.2.3.4/32")
require.NoError(t, err)
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -95,24 +92,12 @@ func TestFirewall_AddRule(t *testing.T) {
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -132,13 +117,6 @@ func TestFirewall_AddRule(t *testing.T) {
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -221,82 +199,6 @@ func TestFirewall_Drop(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
c := dummyCert{
name: "host1",
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
}
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
f := &Firewall{}
ft := FirewallTable{
@@ -306,10 +208,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@@ -341,15 +239,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
ip := netip.MustParsePrefix("fd99::99/128")
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -363,18 +252,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -388,18 +265,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -424,17 +289,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
},
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on name", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -593,42 +447,6 @@ func TestFirewall_Drop3(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
network := netip.MustParsePrefix("fd12::34/120")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{network},
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
@@ -692,50 +510,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
},
}
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
},
}
h1 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c1,
},
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
}
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.0.2.1"),
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
ml := func(m map[string]struct{}, a [][]string) {
for n := 0; n < b.N; n++ {
@@ -953,21 +727,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}

43
go.mod
View File

@@ -1,37 +1,39 @@
module github.com/slackhq/nebula
go 1.25
go 1.23.0
toolchain go1.24.1
require (
dario.cat/mergo v1.0.2
dario.cat/mergo v1.0.1
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.25.0
github.com/gaissmai/bart v0.20.4
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
github.com/miekg/dns v1.1.68
github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.65
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.23.2
github.com/prometheus/client_golang v1.22.0
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.43.0
github.com/stretchr/testify v1.10.0
github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.45.0
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
golang.org/x/term v0.36.0
golang.org/x/net v0.39.0
golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0
golang.org/x/term v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.8
google.golang.org/protobuf v1.36.6
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)
@@ -43,12 +45,11 @@ require (
github.com/google/btree v1.1.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.24.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/mod v0.23.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.33.0 // indirect
golang.org/x/tools v0.30.0 // indirect
)

81
go.sum
View File

@@ -1,6 +1,6 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -64,8 +64,8 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -106,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -143,33 +143,29 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -180,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -189,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -201,17 +197,18 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -222,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -242,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -23,19 +23,15 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return false
}
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
v := cs.initiatingVersion
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
} else if v < cert.Version2 {
// If we're connecting to a v6 address we should encourage use of a V2 cert
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
}
crt := cs.getCertificate(v)
if crt == nil {
@@ -52,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return false
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -108,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available")
return
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -149,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, fperr := rc.Fingerprint()
if fperr != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp = "<error generating certificate fingerprint>"
}
@@ -169,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{
"udpAddr": addr,
"handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
rc := cs.getCertificate(remoteCert.Certificate.Version())
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Unable to handshake with host due to missing certificate version")
return
}
} else {
// Record the certificate we are actually using
ci.myCert = myCertOtherVersion
}
ci.myCert = rc
}
if len(remoteCert.Certificate.Networks()) == 0 {
@@ -258,7 +249,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
@@ -466,9 +457,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake message sent")
}
f.connectionManager.AddTrafficWatch(hostinfo)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
hostinfo.remotes.ResetBlockedRemotes()
return
}
@@ -661,7 +652,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
@@ -676,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
hostinfo.remotes.ResetBlockedRemotes()
f.metricHandshakes.Update(duration)
return false

View File

@@ -70,7 +70,6 @@ type HandshakeHostInfo struct {
startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
@@ -451,7 +450,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},

View File

@@ -4,7 +4,6 @@ import (
"errors"
"net"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
@@ -17,10 +16,12 @@ import (
"github.com/slackhq/nebula/header"
)
// const ProbeLen = 100
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
const defaultReQueryWait = time.Minute // Minimum amount of time to wait before re-querying the lighthouse for a hostinfo. Evaluated every ReQueryEvery
const MaxRemotes = 10
const maxRecvError = 4
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
// 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -67,7 +68,7 @@ type HostMap struct {
type RelayState struct {
sync.RWMutex
relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
// the RelayState Lock held)
@@ -78,12 +79,7 @@ type RelayState struct {
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
for idx, val := range rs.relays {
if val == ip {
rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
return
}
}
delete(rs.relays, ip)
}
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
@@ -128,16 +124,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
if !slices.Contains(rs.relays, ip) {
rs.relays = append(rs.relays, ip)
}
rs.relays[ip] = struct{}{}
}
func (rs *RelayState) CopyRelayIps() []netip.Addr {
ret := make([]netip.Addr, len(rs.relays))
rs.RLock()
defer rs.RUnlock()
copy(ret, rs.relays)
ret := make([]netip.Addr, 0, len(rs.relays))
for ip := range rs.relays {
ret = append(ret, ip)
}
return ret
}
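
The relays field switches here between an ordered slice and a map-backed set. A small sketch of the behavioral difference, assuming nothing beyond the standard library: the slice form preserves insertion order (which the removed ordering test further down this file relied on), while copying out of the map form yields the addresses in no guaranteed order.

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	a1 := netip.MustParseAddr("::1")
	a2 := netip.MustParseAddr("2001::1")

	// Ordered-slice variant: duplicates are skipped, order is insertion order.
	var ordered []netip.Addr
	for _, ip := range []netip.Addr{a1, a2, a1} {
		if !slices.Contains(ordered, ip) {
			ordered = append(ordered, ip)
		}
	}
	fmt.Println(ordered) // always [::1 2001::1]

	// Map-backed set variant: duplicates collapse naturally, order is unspecified.
	set := map[netip.Addr]struct{}{}
	for _, ip := range []netip.Addr{a1, a2, a1} {
		set[ip] = struct{}{}
	}
	unordered := make([]netip.Addr, 0, len(set))
	for ip := range set {
		unordered = append(unordered, ip)
	}
	fmt.Println(unordered) // two addresses, in any order
}
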
@@ -224,6 +220,7 @@ type HostInfo struct {
// The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
recvError atomic.Uint32
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Lite
@@ -253,14 +250,6 @@ type HostInfo struct {
// Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock.
next, prev *HostInfo
//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
in, out, pendingDeletion atomic.Bool
// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
// This value will be behind against actual tunnel utilization in the hot path.
// This should only be used by the ConnectionManagers ticker routine.
lastUsed time.Time
}
type ViaSender struct {
@@ -730,6 +719,13 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}
func (i *HostInfo) RecvErrorExceeded() bool {
if i.recvError.Add(1) >= maxRecvError {
return true
}
return false
}
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed
@@ -738,8 +734,7 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
i.networks = new(bart.Lite)
for _, network := range networks {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
i.networks.Insert(nprefix)
i.networks.Insert(network)
}
for _, network := range unsafeNetworks {

View File

@@ -7,7 +7,6 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHostMap_MakePrimary(t *testing.T) {
@@ -216,31 +215,3 @@ func TestHostMap_reload(t *testing.T) {
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}
func TestHostMap_RelayState(t *testing.T) {
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
a1 := netip.MustParseAddr("::1")
a2 := netip.MustParseAddr("2001::1")
h1.relayState.InsertRelayTo(a1)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
h1.relayState.InsertRelayTo(a2)
assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
// Ensure that the first relay added is the first one returned in the copy
currentRelays := h1.relayState.CopyRelayIps()
require.Len(t, currentRelays, 2)
assert.Equal(t, a1, currentRelays[0])
// Deleting the last one in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
// Deleting an element not in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
// Deleting the only element in the list works ok
h1.relayState.DeleteRelay(a1)
assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
}

147
inside.go
View File

@@ -11,149 +11,6 @@ import (
"github.com/slackhq/nebula/routing"
)
// consumeInsidePackets processes multiple packets in a batch for improved performance
// packets: slice of packet buffers to process
// sizes: slice of packet sizes
// count: number of packets to process
// outs: slice of output buffers (one per packet) with virtio headroom
// q: queue index
// localCache: firewall conntrack cache
// batchPackets: pre-allocated slice for accumulating encrypted packets
// batchAddrs: pre-allocated slice for accumulating destination addresses
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
// Reusable per-packet state
fwPacket := &firewall.Packet{}
// Reset batch accumulation slices (reuse capacity)
*batchPackets = (*batchPackets)[:0]
*batchAddrs = (*batchAddrs)[:0]
// Process each packet in the batch
for i := 0; i < count; i++ {
packet := packets[i][:sizes[i]]
out := outs[i]
// Inline the consumeInsidePacket logic for better performance
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
}
continue
}
// Ignore local broadcast packets
if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
continue
}
}
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
// Immediately forward packets from self to self.
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
}
}
continue
}
// Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
continue
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
}
continue
}
if !ready {
continue
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
}
continue
}
// Encrypt and prepare packet for batch sending
ci := hostinfo.ConnectionState
if ci.eKey == nil {
continue
}
// Check if this needs relay - if so, send immediately and skip batching
useRelay := !hostinfo.remote.IsValid()
if useRelay {
// Handle relay sends individually (less common path)
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
continue
}
// Encrypt the packet for batch sending
if noiseutil.EncryptLockNeeded {
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query lighthouse if needed
if hostinfo.lastRebindCount != f.rebindCount {
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("counter", c).
Error("Failed to encrypt outgoing packet")
continue
}
// Add to batch
*batchPackets = append(*batchPackets, out)
*batchAddrs = append(*batchAddrs, hostinfo.remote)
}
// Send all accumulated packets in one batch
if len(*batchPackets) > 0 {
batchSize := len(*batchPackets)
f.batchMetrics.udpWriteSize.Update(int64(batchSize))
n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
}
}
}
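
The doc comment at the top of this removed function describes an accumulate-then-flush batching scheme: encrypt each packet into its own output buffer, remember its destination, and send everything with a single multi-write at the end. A minimal sketch of that pattern with stand-in names (batchWriter and WriteMulti are illustrative, not the project's real UDP API):

package main

import (
	"fmt"
	"net/netip"
)

// batchWriter stands in for a UDP writer that can send many datagrams per syscall.
type batchWriter struct{}

func (batchWriter) WriteMulti(pkts [][]byte, addrs []netip.AddrPort) (int, error) {
	return len(pkts), nil // pretend every datagram was sent
}

func main() {
	w := batchWriter{}
	batchPackets := make([][]byte, 0, 64)
	batchAddrs := make([]netip.AddrPort, 0, 64)

	inbound := [][]byte{[]byte("pkt-a"), []byte("pkt-b"), []byte("pkt-c")}
	dst := netip.MustParseAddrPort("192.0.2.1:4242")

	for _, p := range inbound {
		out := append([]byte("hdr|"), p...) // stand-in for header encode + encrypt
		batchPackets = append(batchPackets, out)
		batchAddrs = append(batchAddrs, dst)
	}

	// One flush for the whole batch instead of one write per packet.
	if len(batchPackets) > 0 {
		n, err := w.WriteMulti(batchPackets, batchAddrs)
		fmt.Println("sent", n, "err", err)
	}
}
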
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
@@ -431,7 +288,7 @@ func (f *Interface) SendVia(via *HostInfo,
c := via.ConnectionState.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
f.connectionManager.Out(via)
f.connectionManager.Out(via.localIndexId)
// Authenticate the header and payload, but do not encrypt for this message type.
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -499,7 +356,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
f.connectionManager.Out(hostinfo.localIndexId)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming.

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"runtime"
@@ -21,19 +22,18 @@ import (
)
const mtu = 9001
const virtioNetHdrLen = overlay.VirtioNetHdrLen
type InterfaceConfig struct {
HostMap *HostMap
Outside udp.Conn
Inside overlay.Device
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
HandshakeManager *HandshakeManager
lightHouse *LightHouse
connectionManager *connectionManager
checkInterval time.Duration
pendingDeletionInterval time.Duration
DropLocalBroadcast bool
DropMulticast bool
routines int
@@ -50,13 +50,6 @@ type InterfaceConfig struct {
l *logrus.Logger
}
type batchMetrics struct {
udpReadSize metrics.Histogram
tunReadSize metrics.Histogram
udpWriteSize metrics.Histogram
tunWriteSize metrics.Histogram
}
type Interface struct {
hostMap *HostMap
outside udp.Conn
@@ -93,12 +86,11 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []udp.Conn
readers []overlay.BatchReadWriter
readers []io.ReadWriteCloser
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
batchMetrics *batchMetrics
l *logrus.Logger
}
@@ -165,9 +157,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Firewall == nil {
return nil, errors.New("no firewall rules")
}
if c.connectionManager == nil {
return nil, errors.New("no connection manager")
}
cs := c.pki.getCertState()
ifce := &Interface{
@@ -185,14 +174,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]overlay.BatchReadWriter, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager,
connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
@@ -201,12 +190,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
},
batchMetrics: &batchMetrics{
udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)),
tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)),
udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)),
tunWriteSize: metrics.GetOrRegisterHistogram("batch.tun_write_size", nil, metrics.NewUniformSample(1024)),
},
l: c.l,
}
@@ -215,7 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
ifce.connectionManager.intf = ifce
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
return ifce, nil
}
@@ -239,7 +222,7 @@ func (f *Interface) activate() {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
var reader overlay.BatchReadWriter = f.inside
var reader io.ReadWriteCloser = f.inside
for i := 0; i < f.routines; i++ {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
@@ -280,69 +263,39 @@ func (f *Interface) listenOut(i int) {
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
// Pre-allocate output buffers for batch processing
batchSize := li.BatchSize()
outs := make([][]byte, batchSize)
for idx := range outs {
// Allocate full buffer with virtio header space
outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
}
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12)
nb := make([]byte, 12, 12)
li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}
func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
batchSize := reader.BatchSize()
// Allocate buffers for batch reading
bufs := make([][]byte, batchSize)
for idx := range bufs {
bufs[idx] = make([]byte, mtu)
}
sizes := make([]int, batchSize)
// Allocate output buffers for batch processing (one per packet)
// Each has virtio header headroom to avoid copies on write
outs := make([][]byte, batchSize)
for idx := range outs {
outBuf := make([]byte, virtioNetHdrLen+mtu)
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
}
// Pre-allocate batch accumulation buffers for sending
batchPackets := make([][]byte, 0, batchSize)
batchAddrs := make([]netip.AddrPort, 0, batchSize)
// Pre-allocate nonce buffer (reused for all encryptions)
nb := make([]byte, 12)
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.BatchRead(bufs, sizes)
n, err := reader.Read(packet)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
}
f.l.WithError(err).Error("Error while batch reading outbound packets")
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.batchMetrics.tunReadSize.Update(int64(n))
// Process all packets in the batch at once
f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs)
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
}
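
The removed buffer setup above ("Each has virtio header headroom to avoid copies on write") relies on a slicing trick: allocate one buffer with spare bytes at the front, let packet code write past the headroom, then re-expand the slice so a header can be prepended in place. A small sketch, with an illustrative headroom size:

package main

import "fmt"

const headroom = 10 // e.g. the size of a virtio_net_hdr

func main() {
	full := make([]byte, headroom+1500)
	payload := full[headroom:] // the slice handed to packet-processing code

	n := copy(payload, []byte("packet bytes"))

	// To send, grow the slice back toward index 0 so the header region is
	// visible again; the payload bytes were never copied or moved.
	withHeader := full[:headroom+n]
	fmt.Println(len(withHeader)) // headroom + len("packet bytes") = 22
}
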

View File

@@ -24,7 +24,6 @@ import (
)
var ErrHostNotKnown = errors.New("host not known")
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
@@ -57,7 +56,7 @@ type LightHouse struct {
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[[]netip.Addr]
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
interval atomic.Int64
updateCancel context.CancelFunc
@@ -108,7 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
lighthouses := make([]netip.Addr, 0)
lighthouses := make(map[netip.Addr]struct{})
h.lighthouses.Store(&lighthouses)
staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList)
@@ -144,7 +143,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load()
}
func (lh *LightHouse) GetLighthouses() []netip.Addr {
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
return *lh.lighthouses.Load()
}
@@ -307,12 +306,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
if initial || c.HasChanged("lighthouse.hosts") {
lhList, err := lh.parseLighthouses(c)
lhMap := make(map[netip.Addr]struct{})
err := lh.parseLighthouses(c, lhMap)
if err != nil {
return err
}
lh.lighthouses.Store(&lhList)
lh.lighthouses.Store(&lhMap)
if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed")
@@ -346,37 +346,36 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil
}
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
out := make([]netip.Addr, len(lhs))
for i, host := range lhs {
addr, err := netip.ParseAddr(host)
if err != nil {
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
}
if !lh.myVpnNetworksTable.Contains(addr) {
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
}
out[i] = addr
lhMap[addr] = struct{}{}
}
if !lh.amLighthouse && len(out) == 0 {
if !lh.amLighthouse && len(lhMap) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
}
staticList := lh.GetStaticHostList()
for i := range out {
if _, ok := staticList[out[i]]; !ok {
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
for lhAddr, _ := range lhMap {
if _, ok := staticList[lhAddr]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
}
}
return out, nil
return nil
}
func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -487,7 +486,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
return lh.unlockedGetRemoteList(vpnAddrs)
}
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -520,15 +519,11 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
}
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
staticList := lh.GetStaticHostList()
for _, addr := range allVpnAddrs {
if _, ok := staticList[addr]; ok {
// First we check the static mapping
// and do nothing if it is there
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
return
}
}
// None of the VpnAddrs were present. Now we can do the deletes.
lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok {
@@ -570,7 +565,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetAddrs() {
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
continue
}
switch {
@@ -632,30 +627,23 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
return len(calculatedV4) > 0 || len(calculatedV6) > 0
}
// unlockedGetRemoteList assumes you have the lh lock
// unlockedGetRemoteList
// assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
for i, addr := range allAddrs {
am, ok := lh.addrMap[addr]
if ok {
if i != 0 {
lh.addrMap[allAddrs[0]] = am
}
return am
}
}
am := NewRemoteList(allAddrs, lh.shouldAdd)
am, ok := lh.addrMap[allAddrs[0]]
if !ok {
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
for _, addr := range allAddrs {
lh.addrMap[addr] = am
}
}
return am
}
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
}
if !allow {
@@ -710,24 +698,21 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
}
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
l := lh.GetLighthouses()
for i := range l {
if l[i] == vpnAddr {
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
return true
}
}
return false
}
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
l := lh.GetLighthouses()
for i := range vpnAddrs {
for j := range l {
if l[j] == vpnAddrs[i] {
for _, a := range vpnAddr {
if _, ok := l[a]; ok {
return true
}
}
}
return false
}
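
These hunks swap the lighthouse set between a []netip.Addr (linear scan) and a map[netip.Addr]struct{} (constant-time membership test). A hedged sketch of the two lookup styles, with illustrative function names:

package main

import (
	"fmt"
	"net/netip"
)

func isLighthouseSlice(l []netip.Addr, vpnAddr netip.Addr) bool {
	for i := range l {
		if l[i] == vpnAddr {
			return true
		}
	}
	return false
}

func isLighthouseMap(l map[netip.Addr]struct{}, vpnAddr netip.Addr) bool {
	_, ok := l[vpnAddr]
	return ok
}

func main() {
	lh := netip.MustParseAddr("10.128.0.2")
	other := netip.MustParseAddr("10.128.0.3")

	asSlice := []netip.Addr{lh}
	asMap := map[netip.Addr]struct{}{lh: {}}

	fmt.Println(isLighthouseSlice(asSlice, lh), isLighthouseMap(asMap, other)) // true false
}
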
@@ -767,7 +752,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0
lighthouses := lh.GetLighthouses()
for _, lhVpnAddr := range lighthouses {
for lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
@@ -885,7 +870,7 @@ func (lh *LightHouse) SendUpdate() {
updated := 0
lighthouses := lh.GetLighthouses()
for _, lhVpnAddr := range lighthouses {
for lhVpnAddr := range lighthouses {
var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
@@ -943,6 +928,7 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4,
V6AddrPorts: v6,
RelayVpnAddrs: relays,
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
},
}
@@ -1062,19 +1048,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return
}
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
useVersion := cert.Version1
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
useVersion = 1
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = 2
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
}
return
}
@@ -1083,6 +1069,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 {
if !queryVpnAddr.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else {
@@ -1127,9 +1116,8 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok {
whereToPunch = newDest
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
}
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
//choosing to do nothing for now, but maybe we return an error?
}
}
@@ -1188,17 +1176,19 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if !r.Is4() {
continue
}
b = r.As4()
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
}
} else if v == cert.Version2 {
for _, r := range c.relay.relay {
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
}
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("version", v).Debug("unsupported protocol version")
}
//TODO: CERT-V2 don't panic
panic("unsupported version")
}
}
}
@@ -1208,16 +1198,18 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
return
}
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
}
return
lhh.lh.Lock()
var certVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
certVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
relays := n.Details.GetRelays()
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
am.Lock()
lhh.lh.Unlock()
@@ -1242,24 +1234,27 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return
}
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
var detailsVpnAddr netip.Addr
var useVersion cert.Version
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
useVersion := cert.Version1
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
} else if n.Details.VpnAddr != nil {
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2
} else {
detailsVpnAddr = netip.Addr{}
useVersion = cert.Version2
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
}
return
}
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
//Simple check that the host sent this not someone else
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
}
@@ -1273,24 +1268,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock()
n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck
switch useVersion {
case cert.Version1:
if useVersion == cert.Version1 {
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return
}
vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
case cert.Version2:
// do nothing, we want to send a blank message
default:
} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
} else {
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return
}
@@ -1308,20 +1303,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return
}
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
}
return
}
empty := []byte{0}
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
punch := func(vpnPeer netip.AddrPort) {
if !vpnPeer.IsValid() {
return
}
@@ -1333,38 +1321,48 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}()
if lhh.l.Level >= logrus.DebugLevel {
var logVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
logVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
}
}
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts {
b := protoV4AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
punch(protoV4AddrPortToNetAddrPort(a))
}
for _, a := range n.Details.V6AddrPorts {
b := protoV6AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
punch(protoV6AddrPortToNetAddrPort(a))
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
}
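
The comment above explains the punchy "respond" behavior: after a short delay, send a single test packet back toward the address we were told about, which can help open a NAT mapping from our side in double-NAT scenarios. A minimal sketch of that delayed punch-back pattern; the sender interface, delay, and names are placeholders, not the project's real EncWriter:

package main

import (
	"fmt"
	"net/netip"
	"time"
)

type testSender interface {
	SendTest(vpnAddr netip.Addr)
}

type logSender struct{}

func (logSender) SendTest(vpnAddr netip.Addr) { fmt.Println("test packet ->", vpnAddr) }

// respondToPunch fires one test packet after the configured delay.
func respondToPunch(s testSender, vpnAddr netip.Addr, delay time.Duration, done chan<- struct{}) {
	go func() {
		time.Sleep(delay)
		s.SendTest(vpnAddr)
		close(done)
	}()
}

func main() {
	done := make(chan struct{})
	respondToPunch(logSender{}, netip.MustParseAddr("10.128.0.42"), 10*time.Millisecond, done)
	<-done
}
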
@@ -1443,17 +1441,3 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
}
return netip.Addr{}, false
}
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
if d.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
detailsVpnAddr := netip.AddrFrom4(b)
return detailsVpnAddr, cert.Version1, nil
} else if d.VpnAddr != nil {
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
return detailsVpnAddr, cert.Version2, nil
} else {
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
}
}
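
Throughout these lighthouse hunks, a v1 message carries an IPv4 vpn address packed into a big-endian uint32 (OldVpnAddr), while v2 carries a structured address. A worked round-trip of the v1 packing, using only the standard library:

package main

import (
	"encoding/binary"
	"fmt"
	"net/netip"
)

func main() {
	addr := netip.MustParseAddr("10.128.0.42")

	// pack: netip.Addr -> uint32 (what ends up in Details.OldVpnAddr)
	b := addr.As4()
	oldVpnAddr := binary.BigEndian.Uint32(b[:])

	// unpack: uint32 -> netip.Addr (what the query/update handlers above do)
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], oldVpnAddr)
	back := netip.AddrFrom4(out)

	fmt.Println(oldVpnAddr, back, back == addr) // 176160810 10.128.0.42 true
}
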

View File

@@ -493,123 +493,3 @@ func Test_findNetworkUnion(t *testing.T) {
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
testStaticHost := netip.MustParseAddr("10.128.0.42")
//myVpnIp := netip.MustParseAddr("10.128.0.2")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
"10.128.0.42": []any{"1.2.3.4:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//test that we actually have the static entry:
out := lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//bolt on a lower numbered primary IP
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
lh.addrMap[testSameHostNotStatic] = am
out.Rebuild([]netip.Prefix{}) //???
//test that we actually have the static entry:
out = lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
//test that we actually have the static entry for BOTH:
out2 := lh.Query(testSameHostNotStatic)
assert.Same(t, out2, out)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
//verify
out = lh.Query(testSameHostNotStatic)
assert.NotNil(t, out)
if out == nil {
t.Fatal("expected non-nil query for the static host")
}
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
}
func TestLighthouse_DeletesWork(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testHost := netip.MustParseAddr("10.128.0.42")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//insert the host
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
am.vpnAddrs = []netip.Addr{testHost}
am.addrs = []netip.AddrPort{myUdpAddr2}
lh.addrMap[testHost] = am
am.Rebuild([]netip.Prefix{}) //???
//test that we actually have the entry:
out := lh.Query(testHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testHost})
//verify
out = lh.Query(testHost)
assert.Nil(t, out)
}

14
main.go
View File

@@ -75,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
sshStart = nil
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
}
}
@@ -165,7 +164,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
@@ -186,7 +185,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c)
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -222,6 +220,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
@@ -230,8 +231,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
connectionManager: connManager,
lightHouse: lightHouse,
checkInterval: time.Second * time.Duration(checkInterval),
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
@@ -242,6 +244,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
version: buildVersion,
relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
}
@@ -293,6 +296,5 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
connManager.Start,
}, nil
}

View File

@@ -81,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -95,7 +95,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -137,7 +138,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f)
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
@@ -159,7 +160,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
@@ -202,7 +203,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f)
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -212,7 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -253,18 +254,16 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
}
return false
} else {
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
}
return true
@@ -313,11 +312,12 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
next := 0
for {
if protoAt >= dataLen {
if dataLen < offset {
break
}
proto := layers.IPProtocol(data[protoAt])
proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
@@ -365,7 +365,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
case layers.IPProtocolAH:
// Auth headers, used by IPSec, have a different meaning for header length
if dataLen <= offset+1 {
if dataLen < offset+1 {
break
}
@@ -373,7 +373,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
default:
// Normal ipv6 header length processing
if dataLen <= offset+1 {
if dataLen < offset+1 {
break
}
@@ -473,11 +473,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
err = newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return false
}
@@ -492,7 +490,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
@@ -501,7 +499,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
@@ -540,6 +538,10 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
return
}
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
@@ -549,108 +551,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
// We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo)
}
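
The recv_error handling added above only tears a tunnel down once enough errors have accumulated. A small sketch of that threshold pattern with an atomic counter; the host type is illustrative, and the constant mirrors the maxRecvError value introduced earlier in this comparison:

package main

import (
	"fmt"
	"sync/atomic"
)

const maxRecvError = 4

type host struct {
	recvError atomic.Uint32
}

// recvErrorExceeded returns true only once enough recv_errors have been seen.
func (h *host) recvErrorExceeded() bool {
	return h.recvError.Add(1) >= maxRecvError
}

func main() {
	h := &host{}
	for i := 1; i <= 5; i++ {
		fmt.Println(i, h.recvErrorExceeded()) // false false false true true
	}
}
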
// readOutsidePacketsBatch processes multiple packets received from UDP in a batch
// and writes all successfully decrypted packets to TUN in a single operation
func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) {
// Pre-allocate slice for accumulating successful decryptions
tunPackets := make([][]byte, 0, count)
for i := 0; i < count; i++ {
payload := payloads[i]
addr := addrs[i]
out := outs[i]
// Parse header
err := h.Parse(payload)
if err != nil {
if len(payload) > 1 {
f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err)
}
continue
}
if addr.IsValid() {
if f.myVpnNetworksTable.Contains(addr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
}
continue
}
}
var hostinfo *HostInfo
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, addr, h) {
continue
}
switch h.Subtype {
case header.MessageNone:
// Decrypt packet
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
continue
}
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet")
continue
}
if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet")
continue
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet")
}
continue
}
f.connectionManager.In(hostinfo)
// Add to batch for TUN write
tunPackets = append(tunPackets, out)
case header.MessageRelay:
// Skip relay packets in batch mode for now (less common path)
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
default:
hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype)
}
default:
// Handle non-Message types using single-packet path
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
}
}
if len(tunPackets) > 0 {
n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun")
}
f.batchMetrics.tunWriteSize.Update(int64(len(tunPackets)))
}
}

View File

@@ -117,45 +117,6 @@ func Test_newPacket_v6(t *testing.T) {
err = newPacket(buffer.Bytes(), true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A v6 packet with a hop-by-hop extension
// ICMPv6 Payload (Echo Request)
icmpLayer := layers.ICMPv6{
TypeCode: layers.ICMPv6TypeEchoRequest,
}
// Hop-by-Hop Extension Header
hopOption := layers.IPv6HopByHopOption{}
hopOption.OptionData = []byte{0, 0, 0, 0}
hopByHop := layers.IPv6HopByHop{}
hopByHop.Options = append(hopByHop.Options, &hopOption)
ip = layers.IPv6{
Version: 6,
HopLimit: 128,
NextHeader: layers.IPProtocolIPv6Destination,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
buffer.Clear()
err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{
ComputeChecksums: false,
FixLengths: true,
}, &ip, &hopByHop, &icmpLayer)
if err != nil {
panic(err)
}
// Ensure buffer length checks during parsing with the next 2 tests.
// A full IPv6 header and 1 byte in the first extension, but missing
// the length byte.
err = newPacket(buffer.Bytes()[:41], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A full IPv6 header plus 1 full extension, but only 1 byte of the
// next layer, missing length byte
err = newPacket(buffer.Bytes()[:49], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good ICMP packet
ip = layers.IPv6{
Version: 6,
@@ -327,10 +288,6 @@ func Test_newPacket_v6(t *testing.T) {
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// Ensure buffer bounds checking during processing
err = newPacket(b[:41], true, p)
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// Invalid AH header
b = buffer.Bytes()
err = newPacket(b, true, p)

View File

@@ -7,25 +7,11 @@ import (
"github.com/slackhq/nebula/routing"
)
// BatchReadWriter extends io.ReadWriteCloser with batch I/O operations
type BatchReadWriter interface {
io.ReadWriteCloser
// BatchRead reads multiple packets at once
BatchRead(bufs [][]byte, sizes []int) (int, error)
// WriteBatch writes multiple packets at once
WriteBatch(bufs [][]byte, offset int) (int, error)
// BatchSize returns the optimal batch size for this device
BatchSize() int
}
type Device interface {
BatchReadWriter
io.ReadWriteCloser
Activate() error
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (BatchReadWriter, error)
NewMultiQueueReader() (io.ReadWriteCloser, error)
}
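
The platform-specific code removed later in this comparison satisfied the batch-style interface by reporting a batch size of 1 and looping over single-packet Read/Write. A hedged sketch of that fallback, with an illustrative interface definition rather than the project's exact one:

package main

import (
	"bytes"
	"fmt"
	"io"
)

type batchReadWriter interface {
	io.ReadWriter
	BatchRead(bufs [][]byte, sizes []int) (int, error)
	WriteBatch(bufs [][]byte, offset int) (int, error)
	BatchSize() int
}

// singlePacketDevice adapts a plain ReadWriter to the batch interface.
type singlePacketDevice struct{ rw io.ReadWriter }

func (d singlePacketDevice) Read(p []byte) (int, error)  { return d.rw.Read(p) }
func (d singlePacketDevice) Write(p []byte) (int, error) { return d.rw.Write(p) }

func (d singlePacketDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) {
	n, err := d.rw.Read(bufs[0])
	if err != nil {
		return 0, err
	}
	sizes[0] = n
	return 1, nil
}

func (d singlePacketDevice) WriteBatch(bufs [][]byte, offset int) (int, error) {
	for i, buf := range bufs {
		if _, err := d.rw.Write(buf[offset:]); err != nil {
			return i, err
		}
	}
	return len(bufs), nil
}

func (d singlePacketDevice) BatchSize() int { return 1 }

func main() {
	var _ batchReadWriter = singlePacketDevice{} // compile-time interface check
	buf := &bytes.Buffer{}
	buf.WriteString("hello")
	d := singlePacketDevice{rw: buf}
	bufs := [][]byte{make([]byte, 16)}
	sizes := []int{0}
	n, _ := d.BatchRead(bufs, sizes)
	fmt.Println(n, sizes[0]) // 1 5
}
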

View File

@@ -1,8 +1,6 @@
package overlay
import (
"fmt"
"net"
"net/netip"
"github.com/sirupsen/logrus"
@@ -11,7 +9,6 @@ import (
)
const DefaultMTU = 1300
const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
@@ -73,51 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
for _, gateway := range gateways {
if dest.Addr().Is4() && gateway.Addr().Is4() {
return gateway, nil
}
if dest.Addr().Is6() && gateway.Addr().Is6() {
return gateway, nil
}
}
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
}

View File

@@ -95,29 +95,6 @@ func (t *tun) Name() string {
return "android"
}
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
"sync/atomic"
@@ -294,6 +295,7 @@ func (t *tun) activate6(network netip.Prefix) error {
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD,
}
@@ -549,32 +551,16 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
// BatchRead reads a single packet (batch size 1 for non-Linux platforms)
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
// WriteBatch writes packets individually (no batching for non-Linux platforms)
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for non-Linux platforms (no batching)
func (t *tun) BatchSize() int {
return 1
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -105,36 +105,10 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil
}
// BatchRead reads a single packet (batch size 1 for disabled tun)
func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for disabled tun)
func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for disabled tun (no batching)
func (t *disabledTun) BatchSize() int {
return 1
}
func (t *disabledTun) Close() error {
if t.read != nil {
close(t.read)

View File

@@ -10,9 +10,11 @@ import (
"io"
"io/fs"
"net/netip"
"os"
"os/exec"
"strconv"
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/gaissmai/bart"
@@ -20,18 +22,12 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678
TUNSIFMODE = 0x8004745e
TUNSIFHEAD = 0x80047460
OSIOCAIFADDR_IN6 = 0x8088691b
IN6_IFF_NODAD = 0x0020
)
type fiodgnameArg struct {
@@ -41,159 +37,43 @@ type fiodgnameArg struct {
}
type ifreqRename struct {
Name [unix.IFNAMSIZ]byte
Name [16]byte
Data uintptr
}
type ifreqDestroy struct {
Name [unix.IFNAMSIZ]byte
Name [16]byte
pad [16]byte
}
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
}
type ifreqMTU struct {
Name [unix.IFNAMSIZ]byte
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
}
func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
io.ReadWriteCloser
}
func (t *tun) Close() error {
if t.devFd >= 0 {
err := syscall.Close(t.devFd)
if err != nil {
t.l.WithError(err).Error("Error closing device")
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
t.devFd = -1
c := make(chan struct{})
go func() {
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
}
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
return err
}
}()
defer syscall.Close(s)
// wait up to 1 second so we start blocking at the ioctl
select {
case <-c:
case <-time.After(1 * time.Second):
}
ifreq := ifreqDestroy{Name: t.deviceBytes()}
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
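Aside: the Close path above runs the potentially blocking interface destroy in a goroutine and only waits a bounded time for it. A generic sketch of that pattern, with illustrative names (not the actual implementation):

package main

import (
	"errors"
	"fmt"
	"time"
)

// runWithTimeout runs fn in its own goroutine and waits at most d for it to
// finish. If fn is still blocked after d we stop waiting and let the goroutine
// finish in the background, mirroring the bounded wait around SIOCIFDESTROY
// when a read() is still pending on the tun device.
func runWithTimeout(fn func() error, d time.Duration) error {
	done := make(chan error, 1)
	go func() { done <- fn() }()
	select {
	case err := <-done:
		return err
	case <-time.After(d):
		return errors.New("timed out waiting; destroy continues in background")
	}
}

func main() {
	err := runWithTimeout(func() error {
		time.Sleep(2 * time.Second) // stand-in for an ioctl blocked behind a pending read
		return nil
	}, 1*time.Second)
	fmt.Println(err)
}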
@@ -205,37 +85,32 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var fd int
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
}
if err != nil {
return nil, err
}
// Read the name of the interface
rawConn, err := file.SyscallConn()
if err != nil {
return nil, fmt.Errorf("SyscallConn: %v", err)
}
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil {
return nil, err
}
@@ -247,7 +122,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
// If the name doesn't match the desired interface name, rename it now
if ifName != deviceName {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
s, err := syscall.Socket(
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
if err != nil {
return nil, err
}
@@ -270,11 +149,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
devFd: fd,
}
err = t.reload(c, true)
@@ -293,111 +172,38 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
if cidr.Addr().Is6() {
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
return fmt.Errorf("unknown address type %v", cidr)
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return t.addRoutes(false)
}
func (t *tun) setMTU() error {
// Set the MTU on the device
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -450,36 +256,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
// BatchRead reads a single packet (batch size 1 for FreeBSD)
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for FreeBSD)
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for FreeBSD (no batching)
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
@@ -488,16 +268,15 @@ func (t *tun) addRoutes(logErrors bool) error {
continue
}
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -510,8 +289,9 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -526,120 +306,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}
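Aside: the BSD read/write paths above frame every packet with a 4-byte address family in network byte order, chosen from the IP version nibble, and strip it again on read. A platform-neutral sketch of that framing; afInet and afInet6 are stand-ins for the platform's syscall.AF_INET and syscall.AF_INET6 values:

package main

import (
	"encoding/binary"
	"fmt"
)

// Illustrative constants; the real code uses syscall.AF_INET / syscall.AF_INET6,
// whose numeric values differ per platform.
const (
	afInet  = 2
	afInet6 = 30
)

// frame prepends the 4-byte family header, chosen from the IP version nibble.
func frame(pkt []byte) ([]byte, error) {
	if len(pkt) == 0 {
		return nil, fmt.Errorf("empty packet")
	}
	var fam uint32
	switch pkt[0] >> 4 {
	case 4:
		fam = afInet
	case 6:
		fam = afInet6
	default:
		return nil, fmt.Errorf("unable to determine IP version from packet")
	}
	out := make([]byte, 4+len(pkt))
	binary.BigEndian.PutUint32(out[:4], fam)
	copy(out[4:], pkt)
	return out, nil
}

// deframe strips the 4-byte family header, as the Read() paths above do by
// subtracting 4 from the byte count returned by readv().
func deframe(buf []byte) ([]byte, error) {
	if len(buf) < 4 {
		return nil, fmt.Errorf("short read")
	}
	return buf[4:], nil
}

func main() {
	ipv4Header := []byte{0x45, 0x00, 0x00, 0x14} // first bytes of a v4 packet
	framed, _ := frame(ipv4Header)
	payload, _ := deframe(framed)
	fmt.Println(len(framed), len(payload)) // 8 4
}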

View File

@@ -151,29 +151,6 @@ func (t *tun) Name() string {
return "iOS"
}
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -9,6 +9,7 @@ import (
"net"
"net/netip"
"os"
"strings"
"sync/atomic"
"time"
"unsafe"
@@ -20,12 +21,10 @@ import (
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
)
type tun struct {
io.ReadWriteCloser
wgDevice wgtun.Device
fd int
Device string
vpnNetworks []netip.Prefix
@@ -66,154 +65,59 @@ type ifreqQLEN struct {
pad [8]byte
}
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
// This allows multiqueue readers to use the same wireguard Device batching as the main device
type wgDeviceWrapper struct {
dev wgtun.Device
buf []byte // Reusable buffer for single packet reads
}
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
// Use wireguard Device's batch API for single packet
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := w.dev.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := w.dev.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
// Pass all buffers to WireGuard's batch write
return w.dev.Write(bufs, offset)
}
func (w *wgDeviceWrapper) Close() error {
return w.dev.Close()
}
// BatchRead implements batching for multiqueue readers
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
// The zero here is offset.
return w.dev.Read(bufs, sizes, 0)
}
// BatchSize returns the optimal batch size
func (w *wgDeviceWrapper) BatchSize() int {
return w.dev.BatchSize()
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
if err != nil {
return nil, fmt.Errorf("failed to create TUN from FD: %w", err)
}
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
file := wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
_ = wgDev.Close()
return nil, err
}
t.wgDevice = wgDev
t.Device = name
t.Device = "tun0"
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
// Check if /dev/net/tun exists, create if needed (for docker containers)
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
if err := os.MkdirAll("/dev/net", 0755); err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
}
devName := c.GetString("tun.dev", "")
mtu := c.GetInt("tun.mtu", DefaultMTU)
// Create TUN device manually to support multiqueue
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
if os.IsNotExist(err) {
err = os.MkdirAll("/dev/net", 0755)
if err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
if err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
}
} else {
return nil, err
}
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], devName)
copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
// Set nonblocking
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Enable TCP and UDP offload (TSO/GRO) for performance
// This allows the kernel to handle segmentation/coalescing
const (
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
offloads := tunTCPOffloads | tunUDPOffloads
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
// Log warning but don't fail - offload is optional
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
// Create wireguard device from file descriptor
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create TUN from file: %w", err)
}
name, err := wgDev.Name()
if err != nil {
_ = wgDev.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
}
// file is now owned by wgDev, get a new reference
file = wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
_ = wgDev.Close()
return nil, err
}
t.wgDevice = wgDev
t.Device = name
return t, nil
@@ -312,44 +216,22 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
// MUST match the flags used in newTun - includes IFF_VNET_HDR
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
// Set nonblocking mode - CRITICAL for proper netpoller integration
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Get MTU from main device
mtu := t.MaxMTU
if mtu == 0 {
mtu = DefaultMTU
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
// Create wireguard Device from the file descriptor (just like the main device)
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
}
// Return a wrapper that uses the wireguard Device for all I/O
return &wgDeviceWrapper{dev: wgDev}, nil
return file, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -357,68 +239,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r
}
func (t *tun) Read(b []byte) (int, error) {
if t.wgDevice != nil {
// Use wireguard device which handles virtio headers internally
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := t.wgDevice.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
// Fallback: direct read from file (shouldn't happen in normal operation)
return t.ReadWriteCloser.Read(b)
}
// BatchRead reads multiple packets at once for improved performance
// bufs: slice of buffers to read into
// sizes: slice that will be filled with packet sizes
// Returns number of packets read
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Read(bufs, sizes, 0)
}
// Fallback: single packet read
n, err := t.ReadWriteCloser.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// BatchSize returns the optimal number of packets to read/write in a batch
func (t *tun) BatchSize() int {
if t.wgDevice != nil {
return t.wgDevice.BatchSize()
}
return 1
}
func (t *tun) Write(b []byte) (int, error) {
if t.wgDevice != nil {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
// Fallback: direct write (shouldn't happen in normal operation)
var nn int
maximum := len(b)
@@ -441,22 +262,6 @@ func (t *tun) Write(b []byte) (int, error) {
}
}
// WriteBatch writes multiple packets to the TUN device in a single syscall
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Write(bufs, offset)
}
// Fallback: write individually (shouldn't happen in normal operation)
for i, buf := range bufs {
_, err := t.Write(buf)
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
@@ -488,6 +293,7 @@ func (t *tun) addIPs(link netlink.Link) error {
//add all new addresses
for i := range newAddrs {
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
@@ -555,11 +361,6 @@ func (t *tun) Activate() error {
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation")
}
if err = t.addIPs(link); err != nil {
return err
}
@@ -837,11 +638,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return
}
if r.Dst == nil {
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
return
}
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
@@ -869,10 +665,6 @@ func (t *tun) Close() error {
close(t.routeChan)
}
if t.wgDevice != nil {
_ = t.wgDevice.Close()
}
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
}
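Aside: a sketch of how a caller might consume NewMultiQueueReader in its io.ReadWriteCloser form, one goroutine per queue. startQueueReaders, handle, and the in-memory pipe are illustrative only, not part of this change:

package main

import (
	"fmt"
	"io"
)

// startQueueReaders opens n queue readers and runs one read loop per queue.
// newReader stands in for tun.NewMultiQueueReader and handle stands in for
// whatever processes a received packet.
func startQueueReaders(n int, newReader func() (io.ReadWriteCloser, error), handle func(q int, pkt []byte)) error {
	for q := 0; q < n; q++ {
		r, err := newReader()
		if err != nil {
			return fmt.Errorf("queue %d: %w", q, err)
		}
		go func(q int, r io.ReadWriteCloser) {
			buf := make([]byte, 65536) // one packet buffer per queue
			for {
				nRead, err := r.Read(buf)
				if err != nil {
					return // closed or fatal; stop this queue
				}
				handle(q, buf[:nRead])
			}
		}(q, r)
	}
	return nil
}

func main() {
	// A trivial in-memory stand-in so the sketch runs without a tun device.
	pr, pw := io.Pipe()
	type rwc struct {
		io.Reader
		io.Writer
		io.Closer
	}
	got := make(chan int, 1)
	newReader := func() (io.ReadWriteCloser, error) { return rwc{pr, pw, pr}, nil }
	_ = startQueueReaders(1, newReader, func(q int, pkt []byte) { got <- len(pkt) })
	pw.Write([]byte{0x45, 0x00}) // unblocks once the queue goroutine reads it
	fmt.Println("got", <-got, "bytes")
}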

View File

@@ -4,12 +4,13 @@
package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
@@ -19,42 +20,11 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
SIOCAIFADDR_IN6 = 0x8080696b
TUNSIFHEAD = 0x80047442
TUNSIFMODE = 0x80047458
)
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
}
type ifreq struct {
Name [unix.IFNAMSIZ]byte
data int
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
type ifreqDestroy struct {
Name [16]byte
pad [16]byte
}
type tun struct {
@@ -64,18 +34,40 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
io.ReadWriteCloser
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
@@ -85,19 +77,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{
f: os.NewFile(uintptr(fd), ""),
fd: fd,
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -119,225 +105,40 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil
}
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifr := ifreq{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
return err
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
rc, err := t.f.SyscallConn()
if err != nil {
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
}
var errno syscall.Errno
var n uintptr
err = rc.Read(func(fd uintptr) bool {
// first 4 bytes is protocol family, in network byte order
head := [4]byte{}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
if errno.Temporary() {
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
return false
}
return true
})
if err != nil {
if err == syscall.EBADF || err.Error() == "use of closed file" {
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
return 0, os.ErrClosed
}
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
}
if errno != 0 {
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, nil
} else if bytesRead < 4 {
return 0, nil
} else {
return bytesRead - 4, nil
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head [4]byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head[3] = syscall.AF_INET
} else if ipVer == 6 {
head[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var errno syscall.Errno
var n uintptr
err = rc.Write(func(fd uintptr) bool {
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
// According to NetBSD documentation for TUN, writes will only return errors in which
// this packet will never be delivered so just go on living life.
return true
})
if err != nil {
return 0, err
}
if errno != 0 {
return 0, errno
}
return int(n) - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
return nil
}
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime = addrLifetime{
Vltime: 0xffffffff,
Pltime: 0xffffffff,
}
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
mode := int32(unix.IFF_BROADCAST)
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
if err != nil {
return fmt.Errorf("failed to set tun device mode: %w", err)
}
v := 1
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
if err != nil {
return fmt.Errorf("failed to set tun device head: %w", err)
}
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -390,52 +191,27 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -448,8 +224,10 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -464,109 +242,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
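Aside: selectGateway is called by addRoute/delRoute above but its body is not part of this diff. A plausible sketch only, assuming it picks the first VPN network whose address family matches the destination prefix:

package main

import (
	"fmt"
	"net/netip"
)

// selectGateway picks a gateway for dst from the node's VPN networks by
// matching address family. This is a hypothetical reconstruction for
// illustration, not the actual implementation.
func selectGateway(dst netip.Prefix, vpnNetworks []netip.Prefix) (netip.Prefix, error) {
	for _, n := range vpnNetworks {
		if n.Addr().Is4() == dst.Addr().Is4() {
			return n, nil
		}
	}
	return netip.Prefix{}, fmt.Errorf("no gateway found for %v", dst)
}

func main() {
	nets := []netip.Prefix{
		netip.MustParsePrefix("10.1.0.5/24"),
		netip.MustParsePrefix("fd00::5/64"),
	}
	gw, _ := selectGateway(netip.MustParsePrefix("192.168.50.0/24"), nets)
	fmt.Println(gw.Addr()) // 10.1.0.5
}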

View File

@@ -4,50 +4,23 @@
package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
SIOCAIFADDR_IN6 = 0x8080691a
)
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime [2]uint32
}
type ifreq struct {
Name [unix.IFNAMSIZ]byte
data int
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
@@ -55,42 +28,44 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
io.ReadWriteCloser
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
}
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
}
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{
f: os.NewFile(uintptr(fd), ""),
fd: fd,
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -112,154 +87,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil
}
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.f.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.f.Write(buf)
return n - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
}
err = addRoute(cidr, t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
}
return nil
}
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime[0] = 0xffffffff
req.Lifetime[1] = 0xffffffff
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
@@ -297,65 +124,63 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -367,9 +192,10 @@ func (t *tun) removeRoutes(routes []Route) error {
if !r.Install {
continue
}
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -378,115 +204,52 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
func (t *tun) Name() string {
return t.Device
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
return 0, fmt.Errorf("unable to determine IP version from packet")
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
copy(buf[4:], from)
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
n, err := t.ReadWriteCloser.Write(buf)
return n - 4, err
}

View File

@@ -132,29 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil
}
func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}
func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *TestTun) BatchSize() int {
return 1
}

View File

@@ -6,6 +6,7 @@ package overlay
import (
"crypto"
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
@@ -233,36 +234,10 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) {
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
// BatchRead reads a single packet (batch size 1 for Windows)
func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for Windows)
func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for Windows (no batching)
func (t *winTun) BatchSize() int {
return 1
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.

View File

@@ -46,36 +46,10 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) {
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}
// BatchRead reads a single packet (batch size 1 for UserDevice)
func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := d.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for UserDevice)
func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := d.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for UserDevice (no batching)
func (d *UserDevice) BatchSize() int {
return 1
}
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
return d.inboundReader, d.outboundWriter
}

View File

@@ -180,7 +180,6 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
}
// Set up the parameters which include the peer's public key

pki.go
View File

@@ -100,36 +100,41 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
currentState := p.cs.Load()
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
} else {
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil,
)
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new v1 cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
"Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil,
)
}
}
} else if currentState.v1Cert != nil {
//TODO: CERT-V2 we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
}
if newState.v2Cert != nil {
if currentState.v2Cert == nil {
//adding certs is fine, actually
} else {
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil,
)
}
@@ -137,25 +142,13 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil,
)
}
}
} else if currentState.v2Cert != nil {
//newState.v1Cert is non-nil bc empty certstates aren't permitted
if newState.v1Cert == nil {
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
}
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
nil,
)
}
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
}
// Cipher can't be hot swapped so just leave it at what it was before
@@ -180,6 +173,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState)
//TODO: CERT-V2 newState needs a stringer that does json
if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else {
@@ -365,9 +359,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
}
if v1.Networks()[0] != v2.Networks()[0] {
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
}
//TODO: CERT-V2 make sure v2 has v1s address
cs.initiatingVersion = dv
}
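The reloadCerts hunks above all enforce the same rule: a certificate can be swapped at runtime only if its networks and curve match the one already loaded; adding, removing, or re-addressing a cert requires a restart. A minimal sketch of that guard, using placeholder types rather than Nebula's cert.Certificate:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// fakeCert is a stand-in for the loaded certificate state in this sketch.
type fakeCert struct {
	networks []netip.Prefix
	curve    string
}

// canHotReload reports whether newCert may replace oldCert without a restart:
// both must exist, and networks and curve must be unchanged.
func canHotReload(oldCert, newCert *fakeCert) error {
	if oldCert == nil || newCert == nil {
		return fmt.Errorf("adding or removing a certificate requires a restart")
	}
	if !slices.Equal(oldCert.networks, newCert.networks) {
		return fmt.Errorf("networks changed: %v -> %v", oldCert.networks, newCert.networks)
	}
	if oldCert.curve != newCert.curve {
		return fmt.Errorf("curve changed: %v -> %v", oldCert.curve, newCert.curve)
	}
	return nil
}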

View File

@@ -190,7 +190,7 @@ type RemoteList struct {
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
// A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -202,9 +202,7 @@ type RemoteList struct {
cache map[netip.Addr]*cache
hr *hostnamesResults
// shouldAdd is a nillable function that decides if x should be added to addrs.
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
shouldAdd func(netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
@@ -215,7 +213,7 @@ type RemoteList struct {
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0),
@@ -370,15 +368,6 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
return c
}
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
r.Lock()
r.badRemotes = nil
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
copy(r.vpnAddrs, vpnAddrs)
r.Unlock()
}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
@@ -588,7 +577,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
if !r.unlockedIsBad(addr) {
addrs = append(addrs, addr)
}
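Both sides of the RemoteList hunk keep the same idea: shouldAdd is a nillable filter consulted while collecting learned addresses, and the only difference is whether it also receives the host's vpnAddrs. A small sketch of that filtering pattern with illustrative names (not the real RemoteList):

package main

import "net/netip"

// collect keeps a candidate address when no filter is set or the filter
// approves it, mirroring the shouldAdd check in unlockedCollect above.
func collect(candidates []netip.AddrPort, shouldAdd func(netip.Addr) bool) []netip.AddrPort {
	out := make([]netip.AddrPort, 0, len(candidates))
	for _, ap := range candidates {
		if shouldAdd == nil || shouldAdd(ap.Addr()) {
			out = append(out, ap)
		}
	}
	return out
}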

View File

@@ -6,7 +6,6 @@ import (
"log"
"net"
"net/http"
_ "net/http/pprof"
"runtime"
"strconv"
"time"

View File

@@ -13,21 +13,12 @@ type EncReader func(
payload []byte,
)
type EncBatchReader func(
addrs []netip.AddrPort,
payloads [][]byte,
count int,
)
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader)
ListenOutBatch(r EncBatchReader)
WriteTo(b []byte, addr netip.AddrPort) error
WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
ReloadConfig(c *config.C)
BatchSize() int
Close() error
}
@@ -42,21 +33,12 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) ListenOutBatch(_ EncBatchReader) {
return
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}
func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
return 0, nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
return
}
func (NoopConn) BatchSize() int {
return 1
}
func (NoopConn) Close() error {
return nil
}
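With ListenOutBatch, WriteMulti, and BatchSize absent from one side of this Conn interface, a caller that wants to flush several packets simply loops over WriteTo. A sketch against a trimmed stand-in interface (not the real udp.Conn):

package main

import "net/netip"

type writer interface {
	WriteTo(b []byte, addr netip.AddrPort) error
}

// writeAll sends packets[i] to addrs[i] and returns how many packets were
// written before the first error, the same contract the WriteMulti lines
// above describe.
func writeAll(w writer, packets [][]byte, addrs []netip.AddrPort) (int, error) {
	for i := range packets {
		if err := w.WriteTo(packets[i], addrs[i]); err != nil {
			return i, err
		}
	}
	return len(packets), nil
}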

View File

@@ -1,5 +0,0 @@
package udp
import "errors"
var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")
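ErrInvalidIPv6RemoteForSocket is a sentinel error, so callers are expected to match it with errors.Is rather than by string. A short sketch of that consumption side (the send function here is hypothetical):

package main

import (
	"errors"
	"fmt"
)

var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")

// send stands in for a WriteTo call that failed with the sentinel wrapped.
func send() error {
	return fmt.Errorf("send failed: %w", ErrInvalidIPv6RemoteForSocket)
}

func main() {
	if err := send(); errors.Is(err, ErrInvalidIPv6RemoteForSocket) {
		fmt.Println("dropping packet: IPv6 destination on an IPv4-only socket")
	}
}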

View File

@@ -3,62 +3,20 @@
package udp
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"golang.org/x/sys/unix"
)
type StdConn struct {
*net.UDPConn
isV4 bool
sysFd uintptr
l *logrus.Logger
}
var _ Conn = &StdConn{}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
c := &StdConn{UDPConn: uc, l: l}
rc, err := uc.SyscallConn()
if err != nil {
return nil, fmt.Errorf("failed to open udp socket: %w", err)
}
err = rc.Control(func(fd uintptr) {
c.sysFd = fd
})
if err != nil {
return nil, fmt.Errorf("failed to get udp fd: %w", err)
}
la, err := c.LocalAddr()
if err != nil {
return nil, err
}
c.isV4 = la.Addr().Is4()
return c, nil
}
return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
@@ -85,155 +43,16 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
//go:linkname sendto golang.org/x/sys/unix.sendto
//go:noescape
func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
var sa unsafe.Pointer
var addrLen int32
if u.isV4 {
if ap.Addr().Is6() {
return ErrInvalidIPv6RemoteForSocket
}
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4
} else {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet6
}
// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
// See https://github.com/golang/go/issues/73919
for {
//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
err := sendto(int(u.sysFd), b, 0, sa, addrLen)
if err == nil {
// Written, get out before the error handling
return nil
}
if errors.Is(err, syscall.EINTR) {
// Write was interrupted, retry
continue
}
if errors.Is(err, syscall.EAGAIN) {
return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
}
if errors.Is(err, syscall.EBADF) {
return net.ErrClosed
}
return &net.OpError{Op: "sendto", Err: err}
}
}
// WriteMulti sends multiple packets - fallback implementation without sendmmsg
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
func (u *GenericConn) Rebind() error {
rc, err := u.UDPConn.SyscallConn()
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
addr, ok := netip.AddrFromSlice(v.IP)
if !ok {
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
}
return netip.AddrPortFrom(addr, uint16(v.Port)), nil
default:
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
}
}
func (u *StdConn) ReloadConfig(c *config.C) {
// TODO
}
func NewUDPStatsEmitter(udpConns []Conn) func() {
// No UDP stats for non-linux
return func() {}
}
func (u *StdConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}
// ListenOutBatch - fallback to single-packet reads for Darwin
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
buffer := make([]byte, MTU)
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
// Just read one packet at a time and call batch callback with count=1
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
}
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
payloads[0] = buffer[:n]
r(addrs, payloads, 1)
}
}
func (u *StdConn) BatchSize() int {
return 1
}
func (u *StdConn) Rebind() error {
var err error
if u.isV4 {
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
} else {
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
return err
}
return rc.Control(func(fd uintptr) {
err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
if err != nil {
u.l.WithError(err).Error("Failed to rebind udp socket")
}
return nil
})
}
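The darwin WriteTo shown above does its own sendto loop because of the EAGAIN handling noted in the comment: retry on EINTR, report EAGAIN as a would-block OpError, and translate EBADF into net.ErrClosed. A unix-only sketch of just that retry skeleton, with rawSend standing in for the actual sendto call:

package main

import (
	"errors"
	"net"
	"syscall"

	"golang.org/x/sys/unix"
)

// writeRetrying retries interrupted writes and maps the common errno values
// the same way the loop above does.
func writeRetrying(rawSend func() error) error {
	for {
		err := rawSend()
		if err == nil {
			return nil
		}
		if errors.Is(err, syscall.EINTR) {
			continue // interrupted mid-write, try again
		}
		if errors.Is(err, syscall.EAGAIN) {
			return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
		}
		if errors.Is(err, syscall.EBADF) {
			return net.ErrClosed
		}
		return &net.OpError{Op: "sendto", Err: err}
	}
}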

View File

@@ -1,7 +1,6 @@
//go:build (!linux || android) && !e2e_testing && !darwin
//go:build (!linux || android) && !e2e_testing
// +build !linux android
// +build !e2e_testing
// +build !darwin
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows.
@@ -85,42 +84,3 @@ func (u *GenericConn) ListenOut(r EncReader) {
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}
// ListenOutBatch - fallback to single-packet reads for generic platforms
func (u *GenericConn) ListenOutBatch(r EncBatchReader) {
buffer := make([]byte, MTU)
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
// Just read one packet at a time and call batch callback with count=1
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
payloads[0] = buffer[:n]
r(addrs, payloads, 1)
}
}
// WriteMulti sends multiple packets - fallback implementation
func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *GenericConn) BatchSize() int {
return 1
}
func (u *GenericConn) Rebind() error {
return nil
}
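The generic implementation keeps only the one-packet-at-a-time read loop: each datagram is handed to the EncReader callback with its unmapped source address. A sketch of that loop against a plain *net.UDPConn (the MTU and logger are simplified here):

package main

import (
	"log"
	"net"
	"net/netip"
)

type EncReader func(addr netip.AddrPort, payload []byte)

func listenOut(conn *net.UDPConn, mtu int, r EncReader) {
	buffer := make([]byte, mtu)
	for {
		n, rua, err := conn.ReadFromUDPAddrPort(buffer)
		if err != nil {
			log.Printf("udp read loop exiting: %v", err)
			return
		}
		// Unmap drops the ::ffff: prefix from IPv4-mapped IPv6 addresses.
		r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
	}
}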

View File

@@ -22,11 +22,6 @@ type StdConn struct {
isV4 bool
l *logrus.Logger
batch int
// Pre-allocated buffers for batch writes (sized for IPv6, works for both)
writeMsgs []rawMessage
writeIovecs []iovec
writeNames [][]byte
}
func maybeIPV4(ip net.IP) (net.IP, bool) {
@@ -74,26 +69,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err)
}
c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}
// Pre-allocate write message structures for batching (sized for IPv6, works for both)
c.writeMsgs = make([]rawMessage, batch)
c.writeIovecs = make([]iovec, batch)
c.writeNames = make([][]byte, batch)
for i := range c.writeMsgs {
// Allocate for IPv6 size (larger than IPv4, works for both)
c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
// Point to the iovec in the slice
c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i]
c.writeMsgs[i].Hdr.Iovlen = 1
c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0]
// Namelen will be set appropriately in writeMulti4/writeMulti6
}
return c, err
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
}
func (u *StdConn) Rebind() error {
@@ -151,8 +127,6 @@ func (u *StdConn) ListenOut(r EncReader) {
read = u.ReadSingle
}
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
for {
n, err := read(msgs)
if err != nil {
@@ -160,8 +134,6 @@ func (u *StdConn) ListenOut(r EncReader) {
return
}
udpBatchHist.Update(int64(n))
for i := 0; i < n; i++ {
// It's ok to skip the ok check here; the slicing is the only error that can occur and it will panic
if u.isV4 {
@@ -174,46 +146,6 @@ func (u *StdConn) ListenOut(r EncReader) {
}
}
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti
if u.batch == 1 {
read = u.ReadSingle
}
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
// Pre-allocate slices for batch callback
addrs := make([]netip.AddrPort, u.batch)
payloads := make([][]byte, u.batch)
for {
n, err := read(msgs)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpBatchHist.Update(int64(n))
// Prepare batch data
for i := 0; i < n; i++ {
if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payloads[i] = buffers[i][:msgs[i].Len]
}
// Call batch callback with all packets
r(addrs, payloads, n)
}
}
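The udpBatchHist lines in this hunk sample how many datagrams each recvmmsg-style read returns, which shows whether batching is actually paying off. A tiny sketch of that pattern, assuming the metrics package in use is github.com/rcrowley/go-metrics (reads is a placeholder for one batched read):

package main

import metrics "github.com/rcrowley/go-metrics"

// recordBatchSizes updates a histogram with the packet count of each read so
// the distribution of batch fill levels can be inspected later.
func recordBatchSizes(reads func() int, iterations int) {
	hist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
	for i := 0; i < iterations; i++ {
		hist.Update(int64(reads()))
	}
}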
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
@@ -262,19 +194,6 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip)
}
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
if len(packets) != len(addrs) {
return 0, fmt.Errorf("packets and addrs length mismatch")
}
if len(packets) == 0 {
return 0, nil
}
if u.isV4 {
return u.writeMulti4(packets, addrs)
}
return u.writeMulti6(packets, addrs)
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
@@ -302,7 +221,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
if !ip.Addr().Is4() {
return ErrInvalidIPv6RemoteForSocket
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var rsa unix.RawSockaddrInet4
@@ -329,123 +248,6 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
}
}
func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) {
sent := 0
for sent < len(packets) {
// Determine batch size based on remaining packets and buffer capacity
batchSize := len(packets) - sent
if batchSize > len(u.writeMsgs) {
batchSize = len(u.writeMsgs)
}
// Use pre-allocated buffers
msgs := u.writeMsgs[:batchSize]
iovecs := u.writeIovecs[:batchSize]
names := u.writeNames[:batchSize]
// Setup message structures for this batch
for i := 0; i < batchSize; i++ {
pktIdx := sent + i
if !addrs[pktIdx].Addr().Is4() {
return sent + i, ErrInvalidIPv6RemoteForSocket
}
// Setup the packet buffer
iovecs[i].Base = &packets[pktIdx][0]
iovecs[i].Len = uint(len(packets[pktIdx]))
// Setup the destination address
rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0]))
rsa.Family = unix.AF_INET
rsa.Addr = addrs[pktIdx].Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
// Set the appropriate address length for IPv4
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4
}
// Send this batch
nsent, _, err := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(batchSize),
0,
0,
0,
)
if err != 0 {
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
}
sent += int(nsent)
if int(nsent) < batchSize {
// Couldn't send all packets in batch, return what we sent
return sent, nil
}
}
return sent, nil
}
func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) {
sent := 0
for sent < len(packets) {
// Determine batch size based on remaining packets and buffer capacity
batchSize := len(packets) - sent
if batchSize > len(u.writeMsgs) {
batchSize = len(u.writeMsgs)
}
// Use pre-allocated buffers
msgs := u.writeMsgs[:batchSize]
iovecs := u.writeIovecs[:batchSize]
names := u.writeNames[:batchSize]
// Setup message structures for this batch
for i := 0; i < batchSize; i++ {
pktIdx := sent + i
// Setup the packet buffer
iovecs[i].Base = &packets[pktIdx][0]
iovecs[i].Len = uint(len(packets[pktIdx]))
// Setup the destination address
rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0]))
rsa.Family = unix.AF_INET6
rsa.Addr = addrs[pktIdx].Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
// Set the appropriate address length for IPv6
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6
}
// Send this batch
nsent, _, err := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(batchSize),
0,
0,
0,
)
if err != 0 {
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
}
sent += int(nsent)
if int(nsent) < batchSize {
// Couldn't send all packets in batch, return what we sent
return sent, nil
}
}
return sent, nil
}
func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
@@ -503,10 +305,6 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
return nil
}
func (u *StdConn) BatchSize() int {
return u.batch
}
func (u *StdConn) Close() error {
return syscall.Close(u.sysFd)
}
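writeMulti4 and writeMulti6 above share one loop shape: carve the packet slice into chunks no larger than the pre-allocated message array, hand each chunk to sendmmsg, and stop early when the kernel accepts fewer packets than were offered. A sketch of just that chunking logic, with sendBatch standing in for the SYS_SENDMMSG call:

package main

// writeMulti sends packets in chunks of at most maxBatch, returning how many
// packets went out before an error or a partial kernel send.
func writeMulti(packets [][]byte, maxBatch int, sendBatch func(batch [][]byte) (int, error)) (int, error) {
	sent := 0
	for sent < len(packets) {
		batchSize := len(packets) - sent
		if batchSize > maxBatch {
			batchSize = maxBatch
		}
		nsent, err := sendBatch(packets[sent : sent+batchSize])
		sent += nsent
		if err != nil {
			return sent, err
		}
		if nsent < batchSize {
			// Partial send: report what made it out, as the code above does.
			return sent, nil
		}
	}
	return sent, nil
}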

View File

@@ -12,7 +12,7 @@ import (
type iovec struct {
Base *byte
Len uint
Len uint32
}
type msghdr struct {
@@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]

View File

@@ -12,7 +12,7 @@ import (
type iovec struct {
Base *byte
Len uint
Len uint64
}
type msghdr struct {
@@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]
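The two files above differ only in the width of the iovec Len field: the kernel's struct iovec uses size_t, which is 32 bits on 32-bit Linux and 64 bits on 64-bit Linux, so each build-tagged file picks the matching fixed-width type (and plain uint, on the other side of the diff, tracks the same word size). A tiny runnable check of that equivalence:

package main

import (
	"fmt"
	"unsafe"
)

func main() {
	var u uint
	var p uintptr
	// Prints "4 4" on 32-bit targets and "8 8" on 64-bit targets, matching
	// sizeof(size_t) for the platform.
	fmt.Println(unsafe.Sizeof(u), unsafe.Sizeof(p))
}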

View File

@@ -92,25 +92,6 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
// Enable v4 for this socket
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
// Disable reporting of PORT_UNREACHABLE and NET_UNREACHABLE errors from the UDP socket receive call.
// These errors are returned on Windows during UDP receives based on the receipt of ICMP packets. Disable
// the UDP receive error returns with these ioctl calls.
ret := uint32(0)
flag := uint32(0)
size := uint32(unsafe.Sizeof(flag))
err = syscall.WSAIoctl(syscall.Handle(u.sock), syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
if err != nil {
return err
}
ret = 0
flag = 0
size = uint32(unsafe.Sizeof(flag))
SIO_UDP_NETRESET := uint32(syscall.IOC_IN | syscall.IOC_VENDOR | 15)
err = syscall.WSAIoctl(syscall.Handle(u.sock), SIO_UDP_NETRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
if err != nil {
return err
}
err = u.rx.Open()
if err != nil {
return err
@@ -141,13 +122,9 @@ func (u *RIOConn) ListenOut(r EncReader) {
// Just read one packet at a time
n, rua, err := u.receive(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
continue
}
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
}
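The last line of this hunk swaps the two bytes of rua.Port because the RIO sockaddr appears to hold the port in network byte order; the shift-and-mask expression is equivalent to bits.ReverseBytes16. A quick sketch of the conversion:

package main

import (
	"fmt"
	"math/bits"
)

func main() {
	// Port 4242 is 0x1092; stored big-endian and read as a native
	// little-endian uint16 it appears as 0x9210.
	raw := uint16(0x9210)
	swapped := (raw >> 8) | ((raw & 0xff) << 8)
	fmt.Println(swapped, bits.ReverseBytes16(raw)) // both print 4242
}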

View File

@@ -116,31 +116,6 @@ func (u *TesterConn) ListenOut(r EncReader) {
}
}
func (u *TesterConn) ListenOutBatch(r EncBatchReader) {
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
p, ok := <-u.RxPackets
if !ok {
return
}
addrs[0] = p.From
payloads[0] = p.Data
r(addrs, payloads, 1)
}
}
func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *TesterConn) ReloadConfig(*config.C) {}
func NewUDPStatsEmitter(_ []Conn) func() {
@@ -152,10 +127,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil
}
func (u *TesterConn) BatchSize() int {
return 1
}
func (u *TesterConn) Rebind() error {
return nil
}