mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 00:15:37 +01:00
Compare commits
24 Commits
multiport
...
cross-stac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9101b62162 | ||
|
|
d2cb854bff | ||
|
|
9bf9fb14bc | ||
|
|
0f53b8a6ef | ||
|
|
7797927401 | ||
|
|
45c1d3eab3 | ||
|
|
634181ba66 | ||
|
|
eb89839d13 | ||
|
|
fb7f0c3657 | ||
|
|
b1f53d8d25 | ||
|
|
8824eeaea2 | ||
|
|
071589f7c7 | ||
|
|
f1e992f6dd | ||
|
|
1ea5f776d7 | ||
|
|
4cdeb284ef | ||
|
|
5cccd39465 | ||
|
|
8196c22b5a | ||
|
|
65cc253c19 | ||
|
|
73cfa7b5b1 | ||
|
|
768325c9b4 | ||
|
|
932e329164 | ||
|
|
4bea299265 | ||
|
|
5cff83b282 | ||
|
|
7da79685ff |
4
.github/workflows/gofmt.yml
vendored
4
.github/workflows/gofmt.yml
vendored
@@ -16,9 +16,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Install goimports
|
||||
|
||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -12,9 +12,9 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -35,9 +35,9 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -68,9 +68,9 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Import certificates
|
||||
|
||||
4
.github/workflows/smoke-extra.yml
vendored
4
.github/workflows/smoke-extra.yml
vendored
@@ -22,9 +22,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: add hashicorp source
|
||||
|
||||
4
.github/workflows/smoke.yml
vendored
4
.github/workflows/smoke.yml
vendored
@@ -20,9 +20,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: build
|
||||
|
||||
20
.github/workflows/test.yml
vendored
20
.github/workflows/test.yml
vendored
@@ -20,9 +20,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: v2.1
|
||||
version: v2.5
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
@@ -58,9 +58,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -79,9 +79,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -100,9 +100,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build nebula
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: v2.1
|
||||
version: v2.5
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
|
||||
@@ -84,16 +84,11 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
||||
|
||||
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
|
||||
|
||||
rawMap, ok := value.(map[any]any)
|
||||
rawMap, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||
}
|
||||
for rawKey, rawValue := range rawMap {
|
||||
rawCIDR, ok := rawKey.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||
}
|
||||
|
||||
for rawCIDR, rawValue := range rawMap {
|
||||
cidr, err := netip.ParsePrefix(rawCIDR)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||
@@ -129,7 +124,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
|
||||
}
|
||||
|
||||
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
||||
rawMap, ok := raw.(map[any]any)
|
||||
rawMap, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid type: %T", raw)
|
||||
}
|
||||
|
||||
@@ -58,6 +58,9 @@ type Certificate interface {
|
||||
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
||||
PublicKey() []byte
|
||||
|
||||
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
|
||||
MarshalPublicKeyPEM() []byte
|
||||
|
||||
// Curve identifies which curve was used for the PublicKey and Signature.
|
||||
Curve() Curve
|
||||
|
||||
@@ -135,8 +138,7 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
|
||||
case Version2:
|
||||
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
|
||||
default:
|
||||
//TODO: CERT-V2 make a static var
|
||||
return nil, fmt.Errorf("unknown certificate version %d", v)
|
||||
return nil, ErrUnknownVersion
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
|
||||
return c.details.publicKey
|
||||
}
|
||||
|
||||
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
|
||||
return marshalCertPublicKeyToPEM(c)
|
||||
}
|
||||
|
||||
func (c *certificateV1) Signature() []byte {
|
||||
return c.signature
|
||||
}
|
||||
@@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
|
||||
case Curve_CURVE25519:
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
hashed := sha256.Sum256(b)
|
||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||
default:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func TestCertificateV1_Marshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
|
||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||
}
|
||||
|
||||
func TestCertificateV1_PublicKeyPem(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{},
|
||||
unsafeNetworks: []netip.Prefix{},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
publicKey: pubKey,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
assert.Equal(t, Version1, nc.Version())
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.False(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = true
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
||||
require.NoError(t, err)
|
||||
nc.details.curve = Curve_P256
|
||||
nc.details.publicKey = pubP256Key
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = false
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.False(t, nc.IsCA())
|
||||
}
|
||||
|
||||
func TestCertificateV1_Expired(t *testing.T) {
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
|
||||
@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
|
||||
return c.publicKey
|
||||
}
|
||||
|
||||
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
|
||||
return marshalCertPublicKeyToPEM(c)
|
||||
}
|
||||
|
||||
func (c *certificateV2) Signature() []byte {
|
||||
return c.signature
|
||||
}
|
||||
@@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
|
||||
case Curve_CURVE25519:
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
hashed := sha256.Sum256(b)
|
||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||
default:
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func TestCertificateV2_Marshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
|
||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||
}
|
||||
|
||||
func TestCertificateV2_PublicKeyPem(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{},
|
||||
unsafeNetworks: []netip.Prefix{},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
publicKey: pubKey,
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
assert.Equal(t, Version2, nc.Version())
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.False(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = true
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
||||
require.NoError(t, err)
|
||||
nc.curve = Curve_P256
|
||||
nc.publicKey = pubP256Key
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = false
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.False(t, nc.IsCA())
|
||||
}
|
||||
|
||||
func TestCertificateV2_Expired(t *testing.T) {
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
|
||||
@@ -20,6 +20,7 @@ var (
|
||||
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
|
||||
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
|
||||
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
||||
ErrUnknownVersion = errors.New("certificate version unrecognized")
|
||||
|
||||
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
||||
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
||||
|
||||
52
cert/pem.go
52
cert/pem.go
@@ -7,19 +7,26 @@ import (
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
const (
|
||||
CertificateBanner = "NEBULA CERTIFICATE"
|
||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||
const ( //cert banners
|
||||
CertificateBanner = "NEBULA CERTIFICATE"
|
||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||
)
|
||||
|
||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||
const ( //key-agreement-key banners
|
||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||
)
|
||||
|
||||
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
|
||||
const ( //signing key banners
|
||||
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
||||
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
||||
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
|
||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||
)
|
||||
|
||||
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
||||
@@ -51,6 +58,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
||||
|
||||
}
|
||||
|
||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||
if c.IsCA() {
|
||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||
} else {
|
||||
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
|
||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
||||
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
@@ -62,6 +79,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
|
||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
||||
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
|
||||
case Curve_P256:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
k, r := pem.Decode(b)
|
||||
if k == nil {
|
||||
@@ -73,7 +103,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
||||
expectedLen = 32
|
||||
curve = Curve_CURVE25519
|
||||
case P256PublicKeyBanner:
|
||||
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
|
||||
// Uncompressed
|
||||
expectedLen = 65
|
||||
curve = Curve_P256
|
||||
|
||||
@@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
}
|
||||
|
||||
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
||||
t.Parallel()
|
||||
pubKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
@@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
}
|
||||
|
||||
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
pubKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
@@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
oldPubP256Key := []byte(`# A good key
|
||||
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA ECDSA P256 PUBLIC KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A short key
|
||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
@@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-END NEBULA X25519 PUBLIC KEY-----`)
|
||||
|
||||
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
|
||||
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||
assert.Len(t, k, 32)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Len(t, k, 65)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_P256, curve)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Len(t, k, 65)
|
||||
|
||||
12
cert/sign.go
12
cert/sign.go
@@ -7,7 +7,6 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
@@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
|
||||
}
|
||||
return t.SignWith(signer, curve, sp)
|
||||
case Curve_P256:
|
||||
pk := &ecdsa.PrivateKey{
|
||||
PublicKey: ecdsa.PublicKey{
|
||||
Curve: elliptic.P256(),
|
||||
},
|
||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
|
||||
D: new(big.Int).SetBytes(key),
|
||||
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
|
||||
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
|
||||
sp := func(certBytes []byte) ([]byte, error) {
|
||||
// We need to hash first for ECDSA
|
||||
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
||||
|
||||
@@ -356,7 +356,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
decision = tryRehandshake
|
||||
|
||||
} else {
|
||||
if cm.shouldSwapPrimary(hostinfo, primary) {
|
||||
if cm.shouldSwapPrimary(hostinfo) {
|
||||
decision = swapPrimary
|
||||
} else {
|
||||
// migrate the relays to the primary, if in use.
|
||||
@@ -447,7 +447,7 @@ func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time
|
||||
return inactiveDuration, true
|
||||
}
|
||||
|
||||
func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
||||
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
|
||||
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
||||
// Let's sort this out.
|
||||
|
||||
@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
|
||||
addrMap: map[netip.Addr]*RemoteList{},
|
||||
queryChan: make(chan netip.Addr, 10),
|
||||
}
|
||||
lighthouses := map[netip.Addr]struct{}{}
|
||||
lighthouses := []netip.Addr{}
|
||||
staticList := map[netip.Addr]struct{}{}
|
||||
|
||||
lh.lighthouses.Store(&lighthouses)
|
||||
@@ -446,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
|
||||
return d.publicKey
|
||||
}
|
||||
|
||||
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
|
||||
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
|
||||
}
|
||||
|
||||
func (d *dummyCert) Signature() []byte {
|
||||
return d.signature
|
||||
}
|
||||
|
||||
@@ -29,8 +29,6 @@ type m = map[string]any
|
||||
|
||||
// newSimpleServer creates a nebula instance with many assumptions
|
||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
@@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||
}
|
||||
|
||||
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||
}
|
||||
|
||||
if len(vpnNetworks) == 0 {
|
||||
panic("no vpn networks")
|
||||
}
|
||||
|
||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||
|
||||
caB, err := caCrt.MarshalPEM()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -55,3 +56,50 @@ func TestDropInactiveTunnels(t *testing.T) {
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCrossStackRelaysWork(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||
|
||||
//myVpnV4 := myVpnIpNet[0]
|
||||
myVpnV6 := myVpnIpNet[1]
|
||||
relayVpnV4 := relayVpnIpNet[0]
|
||||
relayVpnV6 := relayVpnIpNet[1]
|
||||
theirVpnV6 := theirVpnIpNet[0]
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
relayControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||
|
||||
t.Log("reply?")
|
||||
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||
p = r.RouteForAllUntilTxTun(myControl)
|
||||
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
//t.Log("finish up")
|
||||
//myControl.Stop()
|
||||
//theirControl.Stop()
|
||||
//relayControl.Stop()
|
||||
}
|
||||
|
||||
241
firewall_test.go
241
firewall_test.go
@@ -68,6 +68,9 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
||||
require.NoError(t, err)
|
||||
|
||||
ti6, err := netip.ParsePrefix("fd12::34/128")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
// An empty rule is any
|
||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||
@@ -92,12 +95,24 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||
@@ -117,6 +132,13 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
anyIp6, err := netip.ParsePrefix("::/0")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -199,6 +221,82 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropV6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
c := dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &cert.CachedCertificate{
|
||||
Certificate: &c,
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||
}
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
f := &Firewall{}
|
||||
ft := FirewallTable{
|
||||
@@ -208,6 +306,10 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
|
||||
|
||||
pfix6 := netip.MustParsePrefix("fd11::11/128")
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
b.Run("fail on proto", func(b *testing.B) {
|
||||
@@ -239,6 +341,15 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{},
|
||||
}
|
||||
ip := netip.MustParsePrefix("fd99::99/128")
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -252,6 +363,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -265,6 +388,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on group on any local cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -289,6 +424,17 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on name", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -447,6 +593,42 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3V6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||
LocalPort: 1,
|
||||
RemotePort: 1,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
network := netip.MustParsePrefix("fd12::34/120")
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host-owner",
|
||||
networks: []netip.Prefix{network},
|
||||
},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
// Test a remote address match
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
cp := cert.NewCAPool()
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
@@ -510,6 +692,50 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host-owner",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
|
||||
},
|
||||
}
|
||||
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
|
||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
|
||||
},
|
||||
}
|
||||
h1 := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c1,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("192.0.2.1"),
|
||||
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
|
||||
LocalPort: 1,
|
||||
RemotePort: 1,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
ml := func(m map[string]struct{}, a [][]string) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
@@ -727,6 +953,21 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||
|
||||
// Test adding rule with cidr ipv6
|
||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding rule with local_cidr ipv6
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
|
||||
4
go.mod
4
go.mod
@@ -1,8 +1,6 @@
|
||||
module github.com/slackhq/nebula
|
||||
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.1
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.2
|
||||
|
||||
@@ -459,7 +459,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
|
||||
hostinfo.remotes.ResetBlockedRemotes()
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -667,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||
}
|
||||
|
||||
hostinfo.remotes.ResetBlockedRemotes()
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
f.metricHandshakes.Update(duration)
|
||||
|
||||
return false
|
||||
|
||||
@@ -292,6 +292,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
||||
continue
|
||||
}
|
||||
|
||||
m := NebulaControl{
|
||||
@@ -301,37 +302,25 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
|
||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||
case cert.Version1:
|
||||
if !hm.f.myVpnAddrs[0].Is4() {
|
||||
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
if !vpnIp.Is4() {
|
||||
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
b := hm.f.myVpnAddrs[0].As4()
|
||||
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||
b = vpnIp.As4()
|
||||
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||
err = buildRelayInfoCertV1(&m, hm.f.myVpnNetworks, vpnIp)
|
||||
case cert.Version2:
|
||||
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
||||
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||
err = buildRelayInfoCertV2(&m, hm.f.myVpnNetworks, vpnIp)
|
||||
default:
|
||||
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||
err = errors.New("unknown certificate version found while creating relay")
|
||||
}
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).WithError(err).Error("Refusing to relay")
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).
|
||||
WithError(err).
|
||||
hostinfo.logger(hm.l).WithError(err).
|
||||
Error("Failed to marshal Control message to create relay")
|
||||
} else {
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.f.myVpnAddrs[0],
|
||||
"relayFrom": m.GetRelayFrom(),
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": idx,
|
||||
"relay": relay}).
|
||||
@@ -357,39 +346,27 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
InitiatorRelayIndex: existingRelay.LocalIndex,
|
||||
}
|
||||
|
||||
var err error
|
||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||
case cert.Version1:
|
||||
if !hm.f.myVpnAddrs[0].Is4() {
|
||||
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
if !vpnIp.Is4() {
|
||||
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
b := hm.f.myVpnAddrs[0].As4()
|
||||
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||
b = vpnIp.As4()
|
||||
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||
err = buildRelayInfoCertV1(&m, hm.f.myVpnNetworks, vpnIp)
|
||||
case cert.Version2:
|
||||
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
||||
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||
err = buildRelayInfoCertV2(&m, hm.f.myVpnNetworks, vpnIp)
|
||||
default:
|
||||
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||
err = errors.New("unknown certificate version found while creating relay")
|
||||
}
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).WithError(err).Error("Refusing to relay")
|
||||
continue
|
||||
}
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).
|
||||
WithError(err).
|
||||
Error("Failed to marshal Control message to create relay")
|
||||
hostinfo.logger(hm.l).WithError(err).Error("Failed to marshal Control message to create relay")
|
||||
} else {
|
||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.f.myVpnAddrs[0],
|
||||
"relayFrom": m.GetRelayFrom(),
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
||||
"relay": relay}).
|
||||
@@ -724,3 +701,32 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
|
||||
func hsTimeout(tries int64, interval time.Duration) time.Duration {
|
||||
return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
|
||||
}
|
||||
|
||||
var errNoRelayTooOld = errors.New("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
|
||||
func buildRelayInfoCertV1(m *NebulaControl, myVpnNetworks []netip.Prefix, peerVpnIp netip.Addr) error {
|
||||
relayFrom := myVpnNetworks[0].Addr()
|
||||
if !relayFrom.Is4() {
|
||||
return errNoRelayTooOld
|
||||
}
|
||||
if !peerVpnIp.Is4() {
|
||||
return errNoRelayTooOld
|
||||
}
|
||||
|
||||
b := relayFrom.As4()
|
||||
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||
b = peerVpnIp.As4()
|
||||
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildRelayInfoCertV2(m *NebulaControl, myVpnNetworks []netip.Prefix, peerVpnIp netip.Addr) error {
|
||||
for i := range myVpnNetworks {
|
||||
if myVpnNetworks[i].Contains(peerVpnIp) {
|
||||
m.RelayFromAddr = netAddrToProtoAddr(myVpnNetworks[i].Addr())
|
||||
m.RelayToAddr = netAddrToProtoAddr(peerVpnIp)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("cannot establish relay, no networks in common")
|
||||
}
|
||||
|
||||
22
hostmap.go
22
hostmap.go
@@ -17,12 +17,10 @@ import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
// const ProbeLen = 100
|
||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
|
||||
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
|
||||
const MaxRemotes = 10
|
||||
const maxRecvError = 4
|
||||
|
||||
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
|
||||
// 5 allows for an initial handshake and each host pair re-handshaking twice
|
||||
@@ -225,8 +223,7 @@ type HostInfo struct {
|
||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
||||
// The host may have other vpn addresses that are outside our
|
||||
// vpn networks but were removed because they are not usable
|
||||
vpnAddrs []netip.Addr
|
||||
recvError atomic.Uint32
|
||||
vpnAddrs []netip.Addr
|
||||
|
||||
// networks are both all vpn and unsafe networks assigned to this host
|
||||
networks *bart.Lite
|
||||
@@ -515,13 +512,16 @@ func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
|
||||
return hm.queryVpnAddr(vpnIp, nil)
|
||||
}
|
||||
|
||||
var errUnableToFindHost = errors.New("unable to find host")
|
||||
var errUnableToFindHostWithRelay = errors.New("unable to find host with relay")
|
||||
|
||||
func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
|
||||
hm.RLock()
|
||||
defer hm.RUnlock()
|
||||
|
||||
h, ok := hm.Hosts[relayHostIp]
|
||||
if !ok {
|
||||
return nil, nil, errors.New("unable to find host")
|
||||
return nil, nil, errUnableToFindHost
|
||||
}
|
||||
|
||||
for h != nil {
|
||||
@@ -534,7 +534,7 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
|
||||
h = h.next
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("unable to find host with relay")
|
||||
return nil, nil, errUnableToFindHostWithRelay
|
||||
}
|
||||
|
||||
func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
|
||||
@@ -733,13 +733,6 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *HostInfo) RecvErrorExceeded() bool {
|
||||
if i.recvError.Add(1) >= maxRecvError {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||
// Simple case, no CIDRTree needed
|
||||
@@ -748,7 +741,8 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
|
||||
i.networks = new(bart.Lite)
|
||||
for _, network := range networks {
|
||||
i.networks.Insert(network)
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
i.networks.Insert(nprefix)
|
||||
}
|
||||
|
||||
for _, network := range unsafeNetworks {
|
||||
|
||||
249
lighthouse.go
249
lighthouse.go
@@ -24,6 +24,7 @@ import (
|
||||
)
|
||||
|
||||
var ErrHostNotKnown = errors.New("host not known")
|
||||
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
|
||||
|
||||
type LightHouse struct {
|
||||
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
||||
@@ -56,7 +57,7 @@ type LightHouse struct {
|
||||
// staticList exists to avoid having a bool in each addrMap entry
|
||||
// since static should be rare
|
||||
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
||||
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
|
||||
lighthouses atomic.Pointer[[]netip.Addr]
|
||||
|
||||
interval atomic.Int64
|
||||
updateCancel context.CancelFunc
|
||||
@@ -107,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||
l: l,
|
||||
}
|
||||
lighthouses := make(map[netip.Addr]struct{})
|
||||
lighthouses := make([]netip.Addr, 0)
|
||||
h.lighthouses.Store(&lighthouses)
|
||||
staticList := make(map[netip.Addr]struct{})
|
||||
h.staticList.Store(&staticList)
|
||||
@@ -143,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
|
||||
return *lh.staticList.Load()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
|
||||
func (lh *LightHouse) GetLighthouses() []netip.Addr {
|
||||
return *lh.lighthouses.Load()
|
||||
}
|
||||
|
||||
@@ -306,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("lighthouse.hosts") {
|
||||
lhMap := make(map[netip.Addr]struct{})
|
||||
err := lh.parseLighthouses(c, lhMap)
|
||||
lhList, err := lh.parseLighthouses(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lh.lighthouses.Store(&lhMap)
|
||||
lh.lighthouses.Store(&lhList)
|
||||
if !initial {
|
||||
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
||||
lh.l.Info("lighthouse.hosts has changed")
|
||||
@@ -346,36 +346,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
|
||||
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
||||
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
|
||||
if lh.amLighthouse && len(lhs) != 0 {
|
||||
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
||||
}
|
||||
out := make([]netip.Addr, len(lhs))
|
||||
|
||||
for i, host := range lhs {
|
||||
addr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||
}
|
||||
|
||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||
}
|
||||
lhMap[addr] = struct{}{}
|
||||
out[i] = addr
|
||||
}
|
||||
|
||||
if !lh.amLighthouse && len(lhMap) == 0 {
|
||||
if !lh.amLighthouse && len(out) == 0 {
|
||||
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
|
||||
}
|
||||
|
||||
staticList := lh.GetStaticHostList()
|
||||
for lhAddr, _ := range lhMap {
|
||||
if _, ok := staticList[lhAddr]; !ok {
|
||||
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
|
||||
for i := range out {
|
||||
if _, ok := staticList[out[i]]; !ok {
|
||||
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func getStaticMapCadence(c *config.C) (time.Duration, error) {
|
||||
@@ -486,7 +487,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
|
||||
lh.Lock()
|
||||
defer lh.Unlock()
|
||||
// Add an entry if we don't already have one
|
||||
return lh.unlockedGetRemoteList(vpnAddrs)
|
||||
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
|
||||
}
|
||||
|
||||
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
||||
@@ -519,11 +520,15 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
|
||||
}
|
||||
|
||||
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
||||
// First we check the static mapping
|
||||
// and do nothing if it is there
|
||||
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
|
||||
return
|
||||
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
|
||||
staticList := lh.GetStaticHostList()
|
||||
for _, addr := range allVpnAddrs {
|
||||
if _, ok := staticList[addr]; ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// None of the VpnAddrs were present. Now we can do the deletes.
|
||||
lh.Lock()
|
||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||
if ok {
|
||||
@@ -565,7 +570,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
|
||||
am.unlockedSetHostnamesResults(hr)
|
||||
|
||||
for _, addrPort := range hr.GetAddrs() {
|
||||
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
|
||||
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
@@ -627,23 +632,30 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
|
||||
return len(calculatedV4) > 0 || len(calculatedV6) > 0
|
||||
}
|
||||
|
||||
// unlockedGetRemoteList
|
||||
// assumes you have the lh lock
|
||||
// unlockedGetRemoteList assumes you have the lh lock
|
||||
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
||||
am, ok := lh.addrMap[allAddrs[0]]
|
||||
if !ok {
|
||||
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
|
||||
for _, addr := range allAddrs {
|
||||
lh.addrMap[addr] = am
|
||||
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
|
||||
for i, addr := range allAddrs {
|
||||
am, ok := lh.addrMap[addr]
|
||||
if ok {
|
||||
if i != 0 {
|
||||
lh.addrMap[allAddrs[0]] = am
|
||||
}
|
||||
return am
|
||||
}
|
||||
}
|
||||
|
||||
am := NewRemoteList(allAddrs, lh.shouldAdd)
|
||||
for _, addr := range allAddrs {
|
||||
lh.addrMap[addr] = am
|
||||
}
|
||||
return am
|
||||
}
|
||||
|
||||
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
|
||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
|
||||
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
|
||||
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
||||
Trace("remoteAllowList.Allow")
|
||||
}
|
||||
if !allow {
|
||||
@@ -698,19 +710,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
||||
}
|
||||
|
||||
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
||||
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
|
||||
return true
|
||||
l := lh.GetLighthouses()
|
||||
for i := range l {
|
||||
if l[i] == vpnAddr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
|
||||
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
|
||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
|
||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
||||
l := lh.GetLighthouses()
|
||||
for _, a := range vpnAddr {
|
||||
if _, ok := l[a]; ok {
|
||||
return true
|
||||
for i := range vpnAddrs {
|
||||
for j := range l {
|
||||
if l[j] == vpnAddrs[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
@@ -752,7 +767,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
||||
queried := 0
|
||||
lighthouses := lh.GetLighthouses()
|
||||
|
||||
for lhVpnAddr := range lighthouses {
|
||||
for _, lhVpnAddr := range lighthouses {
|
||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||
if hi != nil {
|
||||
v = hi.ConnectionState.myCert.Version()
|
||||
@@ -870,7 +885,7 @@ func (lh *LightHouse) SendUpdate() {
|
||||
updated := 0
|
||||
lighthouses := lh.GetLighthouses()
|
||||
|
||||
for lhVpnAddr := range lighthouses {
|
||||
for _, lhVpnAddr := range lighthouses {
|
||||
var v cert.Version
|
||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||
if hi != nil {
|
||||
@@ -928,7 +943,6 @@ func (lh *LightHouse) SendUpdate() {
|
||||
V4AddrPorts: v4,
|
||||
V6AddrPorts: v6,
|
||||
RelayVpnAddrs: relays,
|
||||
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1048,19 +1062,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
return
|
||||
}
|
||||
|
||||
useVersion := cert.Version1
|
||||
var queryVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
queryVpnAddr = netip.AddrFrom4(b)
|
||||
useVersion = 1
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
useVersion = 2
|
||||
} else {
|
||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
|
||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
||||
Debugln("Dropping malformed HostQuery")
|
||||
}
|
||||
return
|
||||
}
|
||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1069,9 +1083,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostQueryReply
|
||||
if useVersion == cert.Version1 {
|
||||
if !queryVpnAddr.Is4() {
|
||||
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
|
||||
}
|
||||
b := queryVpnAddr.As4()
|
||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
||||
} else {
|
||||
@@ -1116,8 +1127,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
||||
if ok {
|
||||
whereToPunch = newDest
|
||||
} else {
|
||||
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
|
||||
//choosing to do nothing for now, but maybe we return an error?
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1176,19 +1188,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
||||
if !r.Is4() {
|
||||
continue
|
||||
}
|
||||
|
||||
b = r.As4()
|
||||
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
|
||||
}
|
||||
|
||||
} else if v == cert.Version2 {
|
||||
for _, r := range c.relay.relay {
|
||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||
}
|
||||
|
||||
} else {
|
||||
//TODO: CERT-V2 don't panic
|
||||
panic("unsupported version")
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1198,18 +1208,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
||||
return
|
||||
}
|
||||
|
||||
lhh.lh.Lock()
|
||||
|
||||
var certVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
certVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
||||
}
|
||||
return
|
||||
}
|
||||
relays := n.Details.GetRelays()
|
||||
|
||||
lhh.lh.Lock()
|
||||
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
|
||||
am.Lock()
|
||||
lhh.lh.Unlock()
|
||||
@@ -1234,27 +1242,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
return
|
||||
}
|
||||
|
||||
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
|
||||
var detailsVpnAddr netip.Addr
|
||||
useVersion := cert.Version1
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
var useVersion cert.Version
|
||||
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
detailsVpnAddr = netip.AddrFrom4(b)
|
||||
useVersion = cert.Version1
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
|
||||
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
useVersion = cert.Version2
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
|
||||
}
|
||||
return
|
||||
detailsVpnAddr = netip.Addr{}
|
||||
useVersion = cert.Version2
|
||||
}
|
||||
|
||||
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
|
||||
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
|
||||
//Simple check that the host sent this not someone else
|
||||
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
||||
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
||||
}
|
||||
@@ -1268,24 +1273,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
am.Lock()
|
||||
lhh.lh.Unlock()
|
||||
|
||||
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
||||
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
||||
am.unlockedSetRelay(fromVpnAddrs[0], relays)
|
||||
am.Unlock()
|
||||
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostUpdateNotificationAck
|
||||
|
||||
if useVersion == cert.Version1 {
|
||||
switch useVersion {
|
||||
case cert.Version1:
|
||||
if !fromVpnAddrs[0].Is4() {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
||||
return
|
||||
}
|
||||
vpnAddrB := fromVpnAddrs[0].As4()
|
||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
|
||||
} else if useVersion == cert.Version2 {
|
||||
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
|
||||
} else {
|
||||
case cert.Version2:
|
||||
// do nothing, we want to send a blank message
|
||||
default:
|
||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
||||
return
|
||||
}
|
||||
@@ -1303,13 +1308,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
|
||||
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
|
||||
//maybe one day we'll have a better idea, if it matters.
|
||||
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
|
||||
return
|
||||
}
|
||||
|
||||
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
empty := []byte{0}
|
||||
punch := func(vpnPeer netip.AddrPort) {
|
||||
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
|
||||
if !vpnPeer.IsValid() {
|
||||
return
|
||||
}
|
||||
@@ -1321,48 +1333,31 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
}()
|
||||
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
var logVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
logVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
}
|
||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range n.Details.V4AddrPorts {
|
||||
punch(protoV4AddrPortToNetAddrPort(a))
|
||||
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||
}
|
||||
|
||||
for _, a := range n.Details.V6AddrPorts {
|
||||
punch(protoV6AddrPortToNetAddrPort(a))
|
||||
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||
}
|
||||
|
||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||
// of a double nat or other difficult scenario, this may help establish
|
||||
// a tunnel.
|
||||
if lhh.lh.punchy.GetRespond() {
|
||||
var queryVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
queryVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
|
||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
||||
}
|
||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||
// managed by a channel.
|
||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -1430,7 +1425,7 @@ func (d *NebulaMetaDetails) GetRelays() []netip.Addr {
|
||||
return relays
|
||||
}
|
||||
|
||||
// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able
|
||||
// findNetworkUnion returns the first netip.Addr of addrs contained in the list of provided netip.Prefix, if able
|
||||
func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
|
||||
for i := range prefixes {
|
||||
for j := range addrs {
|
||||
@@ -1441,3 +1436,27 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
|
||||
}
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
|
||||
if d.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
|
||||
detailsVpnAddr := netip.AddrFrom4(b)
|
||||
return detailsVpnAddr, cert.Version1, nil
|
||||
} else if d.VpnAddr != nil {
|
||||
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
|
||||
return detailsVpnAddr, cert.Version2, nil
|
||||
} else {
|
||||
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
|
||||
}
|
||||
}
|
||||
|
||||
func (d *NebulaControl) GetRelayFrom() netip.Addr {
|
||||
if d.OldRelayFromAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], d.OldRelayFromAddr)
|
||||
return netip.AddrFrom4(b)
|
||||
} else {
|
||||
return protoAddrToNetAddr(d.RelayFromAddr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,3 +493,123 @@ func Test_findNetworkUnion(t *testing.T) {
|
||||
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
||||
|
||||
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
|
||||
testStaticHost := netip.MustParseAddr("10.128.0.42")
|
||||
//myVpnIp := netip.MustParseAddr("10.128.0.2")
|
||||
|
||||
c := config.NewC(l)
|
||||
lh1 := "10.128.0.2"
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"hosts": []any{lh1},
|
||||
"interval": "1s",
|
||||
}
|
||||
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
c.Settings["static_host_map"] = map[string]any{
|
||||
lh1: []any{"1.1.1.1:4242"},
|
||||
"10.128.0.42": []any{"1.2.3.4:4242"},
|
||||
}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
|
||||
//test that we actually have the static entry:
|
||||
out := lh.Query(testStaticHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
|
||||
out.Rebuild([]netip.Prefix{}) //why tho
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//bolt on a lower numbered primary IP
|
||||
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
|
||||
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
|
||||
lh.addrMap[testSameHostNotStatic] = am
|
||||
out.Rebuild([]netip.Prefix{}) //???
|
||||
|
||||
//test that we actually have the static entry:
|
||||
out = lh.Query(testStaticHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
||||
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//test that we actually have the static entry for BOTH:
|
||||
out2 := lh.Query(testSameHostNotStatic)
|
||||
assert.Same(t, out2, out)
|
||||
|
||||
//now do the delete
|
||||
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
|
||||
//verify
|
||||
out = lh.Query(testSameHostNotStatic)
|
||||
assert.NotNil(t, out)
|
||||
if out == nil {
|
||||
t.Fatal("expected non-nil query for the static host")
|
||||
}
|
||||
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
||||
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
}
|
||||
|
||||
func TestLighthouse_DeletesWork(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
||||
testHost := netip.MustParseAddr("10.128.0.42")
|
||||
|
||||
c := config.NewC(l)
|
||||
lh1 := "10.128.0.2"
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"hosts": []any{lh1},
|
||||
"interval": "1s",
|
||||
}
|
||||
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
c.Settings["static_host_map"] = map[string]any{
|
||||
lh1: []any{"1.1.1.1:4242"},
|
||||
}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
|
||||
//insert the host
|
||||
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
|
||||
am.vpnAddrs = []netip.Addr{testHost}
|
||||
am.addrs = []netip.AddrPort{myUdpAddr2}
|
||||
lh.addrMap[testHost] = am
|
||||
am.Rebuild([]netip.Prefix{}) //???
|
||||
|
||||
//test that we actually have the entry:
|
||||
out := lh.Query(testHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testHost)
|
||||
out.Rebuild([]netip.Prefix{}) //why tho
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//now do the delete
|
||||
lh.DeleteVpnAddrs([]netip.Addr{testHost})
|
||||
//verify
|
||||
out = lh.Query(testHost)
|
||||
assert.Nil(t, out)
|
||||
}
|
||||
|
||||
18
outside.go
18
outside.go
@@ -254,16 +254,18 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
||||
|
||||
}
|
||||
|
||||
// handleEncrypted returns true if a packet should be processed, false otherwise
|
||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
||||
// If connectionstate exists and the replay protector allows, process packet
|
||||
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
||||
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
|
||||
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
|
||||
if ci == nil {
|
||||
if addr.IsValid() {
|
||||
f.maybeSendRecvError(addr, h.RemoteIndex)
|
||||
return false
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
// If the window check fails, refuse to process the packet, but don't send a recv error
|
||||
if !ci.window.Check(f.l, h.MessageCounter) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -537,10 +539,6 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||
return
|
||||
}
|
||||
|
||||
if !hostinfo.RecvErrorExceeded() {
|
||||
return
|
||||
}
|
||||
|
||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||
return
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -70,3 +72,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||
return addr
|
||||
}
|
||||
|
||||
func flipBytes(b []byte) []byte {
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] ^= 0xFF
|
||||
}
|
||||
return b
|
||||
}
|
||||
func orBytes(a []byte, b []byte) []byte {
|
||||
ret := make([]byte, len(a))
|
||||
for i := 0; i < len(a); i++ {
|
||||
ret[i] = a[i] | b[i]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
||||
broadcast, _ := netip.AddrFromSlice(
|
||||
orBytes(
|
||||
cidr.Addr().AsSlice(),
|
||||
flipBytes(prefixToMask(cidr).AsSlice()),
|
||||
),
|
||||
)
|
||||
return broadcast
|
||||
}
|
||||
|
||||
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
||||
for _, gateway := range gateways {
|
||||
if dest.Addr().Is4() && gateway.Addr().Is4() {
|
||||
return gateway, nil
|
||||
}
|
||||
|
||||
if dest.Addr().Is6() && gateway.Addr().Is6() {
|
||||
return gateway, nil
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
@@ -295,7 +294,6 @@ func (t *tun) activate6(network netip.Prefix) error {
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
|
||||
Flags: _IN6_IFF_NODAD,
|
||||
}
|
||||
|
||||
@@ -554,13 +552,3 @@ func (t *tun) Name() string {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||
return addr
|
||||
}
|
||||
|
||||
@@ -10,11 +10,9 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
@@ -22,12 +20,18 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
||||
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
||||
FIODGNAME = 0x80106678
|
||||
FIODGNAME = 0x80106678
|
||||
TUNSIFMODE = 0x8004745e
|
||||
TUNSIFHEAD = 0x80047460
|
||||
OSIOCAIFADDR_IN6 = 0x8088691b
|
||||
IN6_IFF_NODAD = 0x0020
|
||||
)
|
||||
|
||||
type fiodgnameArg struct {
|
||||
@@ -37,43 +41,159 @@ type fiodgnameArg struct {
|
||||
}
|
||||
|
||||
type ifreqRename struct {
|
||||
Name [16]byte
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Data uintptr
|
||||
}
|
||||
|
||||
type ifreqDestroy struct {
|
||||
Name [16]byte
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
pad [16]byte
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Flags uint16
|
||||
}
|
||||
|
||||
type ifreqMTU struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
MTU int32
|
||||
}
|
||||
|
||||
type addrLifetime struct {
|
||||
Expire uint64
|
||||
Preferred uint64
|
||||
Vltime uint32
|
||||
Pltime uint32
|
||||
}
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
VHid uint32
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime addrLifetime
|
||||
VHid uint32
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
devFd int
|
||||
}
|
||||
|
||||
io.ReadWriteCloser
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
||||
if t.devFd < 0 {
|
||||
return -1, syscall.EINVAL
|
||||
}
|
||||
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
head := make([]byte, 4)
|
||||
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&to[0], uint64(len(to))},
|
||||
}
|
||||
|
||||
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
|
||||
var err error
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
// fix bytes read number to exclude header
|
||||
bytesRead := int(n)
|
||||
if bytesRead < 0 {
|
||||
return bytesRead, err
|
||||
} else if bytesRead < 4 {
|
||||
return 0, err
|
||||
} else {
|
||||
return bytesRead - 4, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
||||
if t.devFd < 0 {
|
||||
return -1, syscall.EINVAL
|
||||
}
|
||||
|
||||
if len(from) <= 1 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
ipVer := from[0] >> 4
|
||||
var head []byte
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
if ipVer == 4 {
|
||||
head = []byte{0, 0, 0, syscall.AF_INET}
|
||||
} else if ipVer == 6 {
|
||||
head = []byte{0, 0, 0, syscall.AF_INET6}
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&from[0], uint64(len(from))},
|
||||
}
|
||||
|
||||
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
|
||||
var err error
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return int(n) - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if t.devFd >= 0 {
|
||||
err := syscall.Close(t.devFd)
|
||||
if err != nil {
|
||||
return err
|
||||
t.l.WithError(err).Error("Error closing device")
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
t.devFd = -1
|
||||
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
||||
defer close(c)
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err == nil {
|
||||
defer syscall.Close(s)
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
}
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Error destroying tunnel")
|
||||
}
|
||||
}()
|
||||
|
||||
// Destroy the interface
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
return err
|
||||
// wait up to 1 second so we start blocking at the ioctl
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var file *os.File
|
||||
var fd int
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName != "" {
|
||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
||||
}
|
||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||
// If the device doesn't already exist, request a new one and rename it
|
||||
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
|
||||
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawConn, err := file.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SyscallConn: %v", err)
|
||||
// Read the name of the interface
|
||||
var name [16]byte
|
||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||
|
||||
if ctrlErr == nil {
|
||||
// set broadcast mode and multicast
|
||||
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
||||
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
||||
}
|
||||
|
||||
if ctrlErr == nil {
|
||||
// turn on link-layer mode, to support ipv6
|
||||
ifhead := uint32(1)
|
||||
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
||||
}
|
||||
|
||||
var name [16]byte
|
||||
var ctrlErr error
|
||||
rawConn.Control(func(fd uintptr) {
|
||||
// Read the name of the interface
|
||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||
})
|
||||
if ctrlErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
|
||||
// If the name doesn't match the desired interface name, rename it now
|
||||
if ifName != deviceName {
|
||||
s, err := syscall.Socket(
|
||||
syscall.AF_INET,
|
||||
syscall.SOCK_DGRAM,
|
||||
syscall.IPPROTO_IP,
|
||||
)
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
devFd: fd,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
if cidr.Addr().Is4() {
|
||||
ifr := ifreqAlias4{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
},
|
||||
DstAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: getBroadcast(cidr).As4(),
|
||||
},
|
||||
MaskAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
},
|
||||
VHid: 0,
|
||||
}
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
if cidr.Addr().Is6() {
|
||||
ifr := ifreqAlias6{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
},
|
||||
PrefixMask: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
},
|
||||
Lifetime: addrLifetime{
|
||||
Expire: 0,
|
||||
Preferred: 0,
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
Flags: IN6_IFF_NODAD,
|
||||
}
|
||||
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
// Setup our default MTU
|
||||
err := t.setMTU()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
linkAddr, err := getLinkAddr(t.Device)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if linkAddr == nil {
|
||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||
}
|
||||
t.linkAddr = linkAddr
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) setMTU() error {
|
||||
// Set the MTU on the device
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
||||
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
@@ -268,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
err := addRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,9 +484,8 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
err := delRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
@@ -306,3 +500,120 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
fmt.Println("DOING CHANGE")
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLinkAddr Gets the link address for the interface of the given name
|
||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
switch m := m.(type) {
|
||||
case *netroute.InterfaceMessage:
|
||||
if m.Name == name {
|
||||
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
||||
if ok {
|
||||
return sa, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -293,7 +293,6 @@ func (t *tun) addIPs(link netlink.Link) error {
|
||||
|
||||
//add all new addresses
|
||||
for i := range newAddrs {
|
||||
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
|
||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||
return err
|
||||
@@ -361,6 +360,11 @@ func (t *tun) Activate() error {
|
||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
}
|
||||
|
||||
const modeNone = 1
|
||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
||||
}
|
||||
|
||||
if err = t.addIPs(link); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -638,6 +642,11 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Dst == nil {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
|
||||
return
|
||||
}
|
||||
|
||||
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
||||
|
||||
@@ -4,13 +4,12 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
@@ -20,11 +19,42 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type ifreqDestroy struct {
|
||||
Name [16]byte
|
||||
pad [16]byte
|
||||
const (
|
||||
SIOCAIFADDR_IN6 = 0x8080696b
|
||||
TUNSIFHEAD = 0x80047442
|
||||
TUNSIFMODE = 0x80047458
|
||||
)
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
type ifreq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
data int
|
||||
}
|
||||
|
||||
type addrLifetime struct {
|
||||
Expire uint64
|
||||
Preferred uint64
|
||||
Vltime uint32
|
||||
Pltime uint32
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
@@ -34,40 +64,18 @@ type tun struct {
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
f *os.File
|
||||
fd int
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
@@ -77,17 +85,23 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
|
||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -105,40 +119,225 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.f != nil {
|
||||
if err := t.f.Close(); err != nil {
|
||||
return fmt.Errorf("error closing tun file: %w", err)
|
||||
}
|
||||
|
||||
// t.f.Close should have handled it for us but let's be extra sure
|
||||
_ = unix.Close(t.fd)
|
||||
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifr := ifreq{Name: t.deviceBytes()}
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
rc, err := t.f.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
||||
}
|
||||
|
||||
var errno syscall.Errno
|
||||
var n uintptr
|
||||
err = rc.Read(func(fd uintptr) bool {
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
head := [4]byte{}
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&to[0], uint64(len(to))},
|
||||
}
|
||||
|
||||
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
if errno.Temporary() {
|
||||
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
if err == syscall.EBADF || err.Error() == "use of closed file" {
|
||||
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
|
||||
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
|
||||
}
|
||||
|
||||
if errno != 0 {
|
||||
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
|
||||
}
|
||||
|
||||
// fix bytes read number to exclude header
|
||||
bytesRead := int(n)
|
||||
if bytesRead < 0 {
|
||||
return bytesRead, nil
|
||||
} else if bytesRead < 4 {
|
||||
return 0, nil
|
||||
} else {
|
||||
return bytesRead - 4, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
if len(from) <= 1 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
ipVer := from[0] >> 4
|
||||
var head [4]byte
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
if ipVer == 4 {
|
||||
head[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
head[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
rc, err := t.f.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var errno syscall.Errno
|
||||
var n uintptr
|
||||
err = rc.Write(func(fd uintptr) bool {
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&from[0], uint64(len(from))},
|
||||
}
|
||||
|
||||
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
// According to NetBSD documentation for TUN, writes will only return errors in which
|
||||
// this packet will never be delivered so just go on living life.
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if errno != 0 {
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
return int(n) - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
var err error
|
||||
if cidr.Addr().Is4() {
|
||||
var req ifreqAlias4
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.DstAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.MaskAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
}
|
||||
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
if cidr.Addr().Is6() {
|
||||
var req ifreqAlias6
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
}
|
||||
req.PrefixMask = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
}
|
||||
req.Lifetime = addrLifetime{
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
}
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
mode := int32(unix.IFF_BROADCAST)
|
||||
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun device mode: %w", err)
|
||||
}
|
||||
|
||||
v := 1
|
||||
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun device head: %w", err)
|
||||
}
|
||||
|
||||
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||
}
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
err = t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
@@ -197,21 +396,23 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,10 +425,8 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
continue
|
||||
}
|
||||
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
@@ -242,3 +441,109 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,23 +4,50 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
SIOCAIFADDR_IN6 = 0x8080691a
|
||||
)
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime [2]uint32
|
||||
}
|
||||
|
||||
type ifreq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
data int
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
@@ -28,48 +55,46 @@ type tun struct {
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
|
||||
f *os.File
|
||||
fd int
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
return t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
|
||||
if !deviceNameRE.MatchString(deviceName) {
|
||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
|
||||
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -87,6 +112,154 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.f != nil {
|
||||
if err := t.f.Close(); err != nil {
|
||||
return fmt.Errorf("error closing tun file: %w", err)
|
||||
}
|
||||
|
||||
// t.f.Close should have handled it for us but let's be extra sure
|
||||
_ = unix.Close(t.fd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.f.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
if cap(buf) < len(from)+4 {
|
||||
buf = make([]byte, len(from)+4)
|
||||
t.out = buf
|
||||
}
|
||||
buf = buf[:len(from)+4]
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
buf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
buf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
copy(buf[4:], from)
|
||||
|
||||
n, err := t.f.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
var req ifreqAlias4
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.DstAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.MaskAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
}
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||
}
|
||||
|
||||
err = addRoute(cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if cidr.Addr().Is6() {
|
||||
var req ifreqAlias6
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
}
|
||||
req.PrefixMask = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
}
|
||||
req.Lifetime[0] = 0xffffffff
|
||||
req.Lifetime[1] = 0xffffffff
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||
}
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err = t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
@@ -124,86 +297,11 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
@@ -213,43 +311,159 @@ func (t *tun) Name() string {
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
n, err := t.ReadWriteCloser.Read(buf)
|
||||
for _, r := range routes {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
if cap(buf) < len(from)+4 {
|
||||
buf = make([]byte, len(from)+4)
|
||||
t.out = buf
|
||||
}
|
||||
buf = buf[:len(from)+4]
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
buf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
buf[3] = syscall.AF_INET6
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
copy(buf[4:], from)
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
n, err := t.ReadWriteCloser.Write(buf)
|
||||
return n - 4, err
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
|
||||
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
|
||||
}
|
||||
|
||||
// Set up the parameters which include the peer's public key
|
||||
|
||||
5
pki.go
5
pki.go
@@ -173,7 +173,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
|
||||
p.cs.Store(newState)
|
||||
|
||||
//TODO: CERT-V2 newState needs a stringer that does json
|
||||
if initial {
|
||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
||||
} else {
|
||||
@@ -359,7 +358,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
||||
}
|
||||
|
||||
//TODO: CERT-V2 make sure v2 has v1s address
|
||||
if v1.Networks()[0] != v2.Networks()[0] {
|
||||
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
|
||||
}
|
||||
|
||||
cs.initiatingVersion = dv
|
||||
}
|
||||
|
||||
@@ -155,6 +155,8 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
||||
"vpnAddrs": h.vpnAddrs}).
|
||||
Info("handleCreateRelayResponse")
|
||||
|
||||
//peer == relayFrom
|
||||
//target == relayTo
|
||||
target := m.RelayToAddr
|
||||
targetAddr := protoAddrToNetAddr(target)
|
||||
|
||||
@@ -190,11 +192,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
||||
InitiatorRelayIndex: peerRelay.RemoteIndex,
|
||||
}
|
||||
|
||||
relayFrom := h.vpnAddrs[0]
|
||||
if v == cert.Version1 {
|
||||
peer := peerHostInfo.vpnAddrs[0]
|
||||
if !peer.Is4() {
|
||||
rm.l.WithField("relayFrom", peer).
|
||||
WithField("relayTo", target).
|
||||
WithField("relayTo", targetAddr).
|
||||
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
|
||||
WithField("responderRelayIndex", resp.ResponderRelayIndex).
|
||||
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
|
||||
@@ -207,7 +210,22 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
||||
b = targetAddr.As4()
|
||||
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||
} else {
|
||||
resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
|
||||
ok = false
|
||||
peerNetworks := h.GetCert().Certificate.Networks()
|
||||
for i := range peerNetworks {
|
||||
if peerNetworks[i].Contains(targetAddr) {
|
||||
relayFrom = peerNetworks[i].Addr()
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
rm.l.WithFields(logrus.Fields{"from": f.myVpnNetworks, "to": targetAddr}).
|
||||
Error("cannot establish relay, no networks in common")
|
||||
return
|
||||
}
|
||||
|
||||
resp.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||
resp.RelayToAddr = target
|
||||
}
|
||||
|
||||
@@ -218,8 +236,8 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
||||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": resp.RelayFromAddr,
|
||||
"relayTo": resp.RelayToAddr,
|
||||
"relayFrom": relayFrom,
|
||||
"relayTo": targetAddr,
|
||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||
"vpnAddrs": peerHostInfo.vpnAddrs}).
|
||||
@@ -313,8 +331,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
|
||||
msg, err := resp.Marshal()
|
||||
if err != nil {
|
||||
logMsg.
|
||||
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
||||
logMsg.WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
||||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
@@ -360,10 +377,10 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: index,
|
||||
}
|
||||
|
||||
relayFrom := h.vpnAddrs[0]
|
||||
if v == cert.Version1 {
|
||||
if !h.vpnAddrs[0].Is4() {
|
||||
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
|
||||
if !relayFrom.Is4() {
|
||||
rm.l.WithField("relayFrom", relayFrom).
|
||||
WithField("relayTo", target).
|
||||
WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
|
||||
WithField("responderRelayIndex", req.ResponderRelayIndex).
|
||||
@@ -372,23 +389,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
return
|
||||
}
|
||||
|
||||
b := h.vpnAddrs[0].As4()
|
||||
b := relayFrom.As4()
|
||||
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||
b = target.As4()
|
||||
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||
} else {
|
||||
req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
|
||||
ok = false
|
||||
peerNetworks := h.GetCert().Certificate.Networks()
|
||||
for i := range peerNetworks {
|
||||
if peerNetworks[i].Contains(target) {
|
||||
relayFrom = peerNetworks[i].Addr()
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
rm.l.WithFields(logrus.Fields{"from": f.myVpnNetworks, "to": target}).
|
||||
Error("cannot establish relay, no networks in common")
|
||||
return
|
||||
}
|
||||
|
||||
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||
req.RelayToAddr = netAddrToProtoAddr(target)
|
||||
}
|
||||
|
||||
msg, err := req.Marshal()
|
||||
if err != nil {
|
||||
logMsg.
|
||||
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
|
||||
logMsg.WithError(err).Error("relayManager Failed to marshal Control message to create relay")
|
||||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": h.vpnAddrs[0],
|
||||
"relayFrom": relayFrom,
|
||||
"relayTo": target,
|
||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||
"responderRelayIndex": req.ResponderRelayIndex,
|
||||
@@ -401,8 +432,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
if !ok {
|
||||
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
|
||||
if err != nil {
|
||||
logMsg.
|
||||
WithError(err).Error("relayManager Failed to allocate a local index for relay")
|
||||
logMsg.WithError(err).Error("relayManager Failed to allocate a local index for relay")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ type RemoteList struct {
|
||||
// The full list of vpn addresses assigned to this host
|
||||
vpnAddrs []netip.Addr
|
||||
|
||||
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
||||
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
|
||||
addrs []netip.AddrPort
|
||||
|
||||
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
|
||||
@@ -201,8 +201,10 @@ type RemoteList struct {
|
||||
// For learned addresses, this is the vpnIp that sent the packet
|
||||
cache map[netip.Addr]*cache
|
||||
|
||||
hr *hostnamesResults
|
||||
shouldAdd func(netip.Addr) bool
|
||||
hr *hostnamesResults
|
||||
|
||||
// shouldAdd is a nillable function that decides if x should be added to addrs.
|
||||
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
|
||||
|
||||
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
||||
// They should not be tried again during a handshake
|
||||
@@ -213,7 +215,7 @@ type RemoteList struct {
|
||||
}
|
||||
|
||||
// NewRemoteList creates a new empty RemoteList
|
||||
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
|
||||
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
|
||||
r := &RemoteList{
|
||||
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
|
||||
addrs: make([]netip.AddrPort, 0),
|
||||
@@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
|
||||
return c
|
||||
}
|
||||
|
||||
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
|
||||
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
|
||||
r.Lock()
|
||||
r.badRemotes = nil
|
||||
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
|
||||
copy(r.vpnAddrs, vpnAddrs)
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
// ResetBlockedRemotes locks and clears the blocked remotes list
|
||||
func (r *RemoteList) ResetBlockedRemotes() {
|
||||
r.Lock()
|
||||
@@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
|
||||
|
||||
dnsAddrs := r.hr.GetAddrs()
|
||||
for _, addr := range dnsAddrs {
|
||||
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
||||
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
|
||||
if !r.unlockedIsBad(addr) {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user