mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-24 01:14:25 +01:00
Compare commits
15 Commits
channels-s
...
cert-v2-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fa386bb70 | ||
|
|
f488873d70 | ||
|
|
c073eebe42 | ||
|
|
1a1255d557 | ||
|
|
32649de665 | ||
|
|
26a00a5647 | ||
|
|
888ba400b9 | ||
|
|
dc3081ea49 | ||
|
|
68bbb53b90 | ||
|
|
41273a94bb | ||
|
|
0946831f88 | ||
|
|
d2d3e21780 | ||
|
|
241b0a6d7f | ||
|
|
0721dde24b | ||
|
|
a6640b4540 |
3
bits.go
3
bits.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
|
|
||||||
type Bits struct {
|
type Bits struct {
|
||||||
length uint64
|
length uint64
|
||||||
current uint64
|
current uint64
|
||||||
@@ -44,7 +43,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
l.Error("rejected a packet (top) %d %d\n", b.current, i)
|
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,9 +58,6 @@ type Certificate interface {
|
|||||||
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
||||||
PublicKey() []byte
|
PublicKey() []byte
|
||||||
|
|
||||||
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
|
|
||||||
MarshalPublicKeyPEM() []byte
|
|
||||||
|
|
||||||
// Curve identifies which curve was used for the PublicKey and Signature.
|
// Curve identifies which curve was used for the PublicKey and Signature.
|
||||||
Curve() Curve
|
Curve() Curve
|
||||||
|
|
||||||
|
|||||||
@@ -83,10 +83,6 @@ func (c *certificateV1) PublicKey() []byte {
|
|||||||
return c.details.publicKey
|
return c.details.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
|
|
||||||
return marshalCertPublicKeyToPEM(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certificateV1) Signature() []byte {
|
func (c *certificateV1) Signature() []byte {
|
||||||
return c.signature
|
return c.signature
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ed25519"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -14,7 +13,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCertificateV1_Marshal(t *testing.T) {
|
func TestCertificateV1_Marshal(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||||
@@ -62,58 +60,6 @@ func TestCertificateV1_Marshal(t *testing.T) {
|
|||||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_PublicKeyPem(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
|
||||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
|
||||||
|
|
||||||
nc := certificateV1{
|
|
||||||
details: detailsV1{
|
|
||||||
name: "testing",
|
|
||||||
networks: []netip.Prefix{},
|
|
||||||
unsafeNetworks: []netip.Prefix{},
|
|
||||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
|
||||||
notBefore: before,
|
|
||||||
notAfter: after,
|
|
||||||
publicKey: pubKey,
|
|
||||||
isCA: false,
|
|
||||||
issuer: "1234567890abcedfghij1234567890ab",
|
|
||||||
},
|
|
||||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, Version1, nc.Version())
|
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
|
||||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
|
||||||
assert.False(t, nc.IsCA())
|
|
||||||
|
|
||||||
nc.details.isCA = true
|
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
|
||||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
|
||||||
assert.True(t, nc.IsCA())
|
|
||||||
|
|
||||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
|
||||||
-----END NEBULA P256 PUBLIC KEY-----
|
|
||||||
`)
|
|
||||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
|
||||||
require.NoError(t, err)
|
|
||||||
nc.details.curve = Curve_P256
|
|
||||||
nc.details.publicKey = pubP256Key
|
|
||||||
assert.Equal(t, Curve_P256, nc.Curve())
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
|
||||||
assert.True(t, nc.IsCA())
|
|
||||||
|
|
||||||
nc.details.isCA = false
|
|
||||||
assert.Equal(t, Curve_P256, nc.Curve())
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
|
||||||
assert.False(t, nc.IsCA())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateV1_Expired(t *testing.T) {
|
func TestCertificateV1_Expired(t *testing.T) {
|
||||||
nc := certificateV1{
|
nc := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
|
|||||||
@@ -114,10 +114,6 @@ func (c *certificateV2) PublicKey() []byte {
|
|||||||
return c.publicKey
|
return c.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
|
|
||||||
return marshalCertPublicKeyToPEM(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certificateV2) Signature() []byte {
|
func (c *certificateV2) Signature() []byte {
|
||||||
return c.signature
|
return c.signature
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCertificateV2_Marshal(t *testing.T) {
|
func TestCertificateV2_Marshal(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||||
@@ -76,58 +75,6 @@ func TestCertificateV2_Marshal(t *testing.T) {
|
|||||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_PublicKeyPem(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
|
||||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
|
||||||
|
|
||||||
nc := certificateV2{
|
|
||||||
details: detailsV2{
|
|
||||||
name: "testing",
|
|
||||||
networks: []netip.Prefix{},
|
|
||||||
unsafeNetworks: []netip.Prefix{},
|
|
||||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
|
||||||
notBefore: before,
|
|
||||||
notAfter: after,
|
|
||||||
isCA: false,
|
|
||||||
issuer: "1234567890abcedfghij1234567890ab",
|
|
||||||
},
|
|
||||||
publicKey: pubKey,
|
|
||||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, Version2, nc.Version())
|
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
|
||||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
|
||||||
assert.False(t, nc.IsCA())
|
|
||||||
|
|
||||||
nc.details.isCA = true
|
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
|
||||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
|
||||||
assert.True(t, nc.IsCA())
|
|
||||||
|
|
||||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
|
||||||
-----END NEBULA P256 PUBLIC KEY-----
|
|
||||||
`)
|
|
||||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
|
||||||
require.NoError(t, err)
|
|
||||||
nc.curve = Curve_P256
|
|
||||||
nc.publicKey = pubP256Key
|
|
||||||
assert.Equal(t, Curve_P256, nc.Curve())
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
|
||||||
assert.True(t, nc.IsCA())
|
|
||||||
|
|
||||||
nc.details.isCA = false
|
|
||||||
assert.Equal(t, Curve_P256, nc.Curve())
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
|
||||||
assert.False(t, nc.IsCA())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateV2_Expired(t *testing.T) {
|
func TestCertificateV2_Expired(t *testing.T) {
|
||||||
nc := certificateV2{
|
nc := certificateV2{
|
||||||
details: detailsV2{
|
details: detailsV2{
|
||||||
|
|||||||
44
cert/pem.go
44
cert/pem.go
@@ -7,26 +7,19 @@ import (
|
|||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ( //cert banners
|
const (
|
||||||
CertificateBanner = "NEBULA CERTIFICATE"
|
CertificateBanner = "NEBULA CERTIFICATE"
|
||||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||||
)
|
|
||||||
|
|
||||||
const ( //key-agreement-key banners
|
|
||||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
|
||||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
|
|
||||||
const ( //signing key banners
|
|
||||||
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
|
||||||
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
|
||||||
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
|
|
||||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||||
|
|
||||||
|
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||||
|
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||||
|
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
||||||
|
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
||||||
@@ -58,16 +51,6 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
|
||||||
if c.IsCA() {
|
|
||||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
|
||||||
} else {
|
|
||||||
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
|
|
||||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
|
||||||
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||||
switch curve {
|
switch curve {
|
||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
@@ -79,19 +62,6 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
|
|
||||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
|
||||||
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
|
|
||||||
switch curve {
|
|
||||||
case Curve_CURVE25519:
|
|
||||||
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
|
|
||||||
case Curve_P256:
|
|
||||||
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||||
k, r := pem.Decode(b)
|
k, r := pem.Decode(b)
|
||||||
if k == nil {
|
if k == nil {
|
||||||
@@ -103,7 +73,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
|||||||
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
||||||
expectedLen = 32
|
expectedLen = 32
|
||||||
curve = Curve_CURVE25519
|
curve = Curve_CURVE25519
|
||||||
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
|
case P256PublicKeyBanner:
|
||||||
// Uncompressed
|
// Uncompressed
|
||||||
expectedLen = 65
|
expectedLen = 65
|
||||||
curve = Curve_P256
|
curve = Curve_P256
|
||||||
|
|||||||
@@ -177,7 +177,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
pubKey := []byte(`# A good key
|
pubKey := []byte(`# A good key
|
||||||
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
@@ -231,7 +230,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
pubKey := []byte(`# A good key
|
pubKey := []byte(`# A good key
|
||||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
@@ -242,12 +240,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
-----END NEBULA P256 PUBLIC KEY-----
|
-----END NEBULA P256 PUBLIC KEY-----
|
||||||
`)
|
|
||||||
oldPubP256Key := []byte(`# A good key
|
|
||||||
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
|
||||||
-----END NEBULA ECDSA P256 PUBLIC KEY-----
|
|
||||||
`)
|
`)
|
||||||
shortKey := []byte(`# A short key
|
shortKey := []byte(`# A short key
|
||||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||||
@@ -264,22 +256,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
-END NEBULA X25519 PUBLIC KEY-----`)
|
-END NEBULA X25519 PUBLIC KEY-----`)
|
||||||
|
|
||||||
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
|
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||||
assert.Len(t, k, 32)
|
assert.Len(t, k, 32)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
|
|
||||||
// Success test case
|
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
|
||||||
assert.Len(t, k, 65)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
|
||||||
assert.Equal(t, Curve_P256, curve)
|
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Len(t, k, 65)
|
assert.Len(t, k, 65)
|
||||||
|
|||||||
@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
|||||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
||||||
|
nc := &cert.TBSCertificate{
|
||||||
|
Version: v,
|
||||||
|
Curve: c.Curve(),
|
||||||
|
Name: c.Name(),
|
||||||
|
Networks: c.Networks(),
|
||||||
|
UnsafeNetworks: c.UnsafeNetworks(),
|
||||||
|
Groups: c.Groups(),
|
||||||
|
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
||||||
|
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
||||||
|
PublicKey: c.PublicKey(),
|
||||||
|
IsCA: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pem, err := c.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, pem
|
||||||
|
}
|
||||||
|
|
||||||
func X25519Keypair() ([]byte, []byte) {
|
func X25519Keypair() ([]byte, []byte) {
|
||||||
privkey := make([]byte, 32)
|
privkey := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||||
|
|||||||
191
cmd/gso/gso.go
191
cmd/gso/gso.go
@@ -1,191 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// UDP_SEGMENT enables GSO segmentation
|
|
||||||
UDP_SEGMENT = 103
|
|
||||||
// Maximum GSO segment size (typical MTU - headers)
|
|
||||||
maxGSOSize = 1400
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
destAddr := flag.String("dest", "10.4.0.16:4202", "Destination address")
|
|
||||||
gsoSize := flag.Int("gso", 1400, "GSO segment size")
|
|
||||||
totalSize := flag.Int("size", 14000, "Total payload size to send")
|
|
||||||
count := flag.Int("count", 1, "Number of packets to send")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if *gsoSize > maxGSOSize {
|
|
||||||
log.Fatalf("GSO size %d exceeds maximum %d", *gsoSize, maxGSOSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve destination address
|
|
||||||
_, err := net.ResolveUDPAddr("udp", *destAddr)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to resolve address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a raw UDP socket with GSO support
|
|
||||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to create socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(fd)
|
|
||||||
|
|
||||||
// Bind to a local address
|
|
||||||
localAddr := &unix.SockaddrInet4{
|
|
||||||
Port: 0, // Let the system choose a port
|
|
||||||
}
|
|
||||||
if err := unix.Bind(fd, localAddr); err != nil {
|
|
||||||
log.Fatalf("Failed to bind socket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Sending UDP packets with GSO enabled\n")
|
|
||||||
fmt.Printf("Destination: %s\n", *destAddr)
|
|
||||||
fmt.Printf("GSO segment size: %d bytes\n", *gsoSize)
|
|
||||||
fmt.Printf("Total payload size: %d bytes\n", *totalSize)
|
|
||||||
fmt.Printf("Number of packets: %d\n\n", *count)
|
|
||||||
|
|
||||||
// Create payload
|
|
||||||
payload := make([]byte, *totalSize)
|
|
||||||
for i := range payload {
|
|
||||||
payload[i] = byte(i % 256)
|
|
||||||
}
|
|
||||||
|
|
||||||
dest := netip.MustParseAddrPort(*destAddr)
|
|
||||||
|
|
||||||
//if err := unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_SEGMENT, 1400); err != nil {
|
|
||||||
// panic(err)
|
|
||||||
//}
|
|
||||||
|
|
||||||
for i := 0; i < *count; i++ {
|
|
||||||
err := WriteBatch(fd, payload, dest, uint16(*gsoSize), true)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Send error on packet %d: %v", i, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i+1)%100 == 0 || i == *count-1 {
|
|
||||||
fmt.Printf("Sent %d packets\n", i+1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Printf("now, let's send without the correct ctrl header\n")
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
for i := 0; i < *count; i++ {
|
|
||||||
err := WriteBatch(fd, payload, dest, uint16(*gsoSize), false)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Send error on packet %d: %v", i, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i+1)%100 == 0 || i == *count-1 {
|
|
||||||
fmt.Printf("Sent %d packets\n", i+1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteBatch(fd int, payload []byte, addr netip.AddrPort, segSize uint16, withHeader bool) error {
|
|
||||||
msgs := make([]rawMessage, 0, 1)
|
|
||||||
iovs := make([]iovec, 0, 1)
|
|
||||||
names := make([][unix.SizeofSockaddrInet6]byte, 0, 1)
|
|
||||||
|
|
||||||
sent := 0
|
|
||||||
|
|
||||||
pkts := []BatchPacket{
|
|
||||||
{
|
|
||||||
Payload: payload,
|
|
||||||
Addr: addr,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, pkt := range pkts {
|
|
||||||
if len(pkt.Payload) == 0 {
|
|
||||||
sent++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs = append(msgs, rawMessage{})
|
|
||||||
iovs = append(iovs, iovec{})
|
|
||||||
names = append(names, [unix.SizeofSockaddrInet6]byte{})
|
|
||||||
|
|
||||||
idx := len(msgs) - 1
|
|
||||||
msg := &msgs[idx]
|
|
||||||
iov := &iovs[idx]
|
|
||||||
name := &names[idx]
|
|
||||||
|
|
||||||
setIovecSlice(iov, pkt.Payload)
|
|
||||||
msg.Hdr.Iov = iov
|
|
||||||
msg.Hdr.Iovlen = 1
|
|
||||||
|
|
||||||
if withHeader {
|
|
||||||
setRawMessageControl(msg, buildGSOControlMessage(segSize)) //
|
|
||||||
} else {
|
|
||||||
setRawMessageControl(msg, nil) //
|
|
||||||
}
|
|
||||||
|
|
||||||
msg.Hdr.Flags = 0
|
|
||||||
|
|
||||||
nameLen, err := encodeSockaddr(name[:], pkt.Addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.Hdr.Name = &name[0]
|
|
||||||
msg.Hdr.Namelen = nameLen
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(msgs) == 0 {
|
|
||||||
return errors.New("nothing to write")
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := 0
|
|
||||||
for offset < len(msgs) {
|
|
||||||
n, _, errno := unix.Syscall6(
|
|
||||||
unix.SYS_SENDMMSG,
|
|
||||||
uintptr(fd),
|
|
||||||
uintptr(unsafe.Pointer(&msgs[offset])),
|
|
||||||
uintptr(len(msgs)-offset),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
if errno == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return &net.OpError{Op: "sendmmsg", Err: errno}
|
|
||||||
}
|
|
||||||
|
|
||||||
if n == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
offset += int(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildGSOControlMessage(segSize uint16) []byte {
|
|
||||||
control := make([]byte, unix.CmsgSpace(2))
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
hdr.Level = unix.SOL_UDP
|
|
||||||
hdr.Type = unix.UDP_SEGMENT
|
|
||||||
setCmsgLen(hdr, unix.CmsgLen(2))
|
|
||||||
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
|
|
||||||
|
|
||||||
return control
|
|
||||||
}
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
type iovec struct {
|
|
||||||
Base *byte
|
|
||||||
Len uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
type msghdr struct {
|
|
||||||
Name *byte
|
|
||||||
Namelen uint32
|
|
||||||
Pad0 [4]byte
|
|
||||||
Iov *iovec
|
|
||||||
Iovlen uint64
|
|
||||||
Control *byte
|
|
||||||
Controllen uint64
|
|
||||||
Flags int32
|
|
||||||
Pad1 [4]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type rawMessage struct {
|
|
||||||
Hdr msghdr
|
|
||||||
Len uint32
|
|
||||||
Pad0 [4]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type BatchPacket struct {
|
|
||||||
Payload []byte
|
|
||||||
Addr netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
|
||||||
if addr.Addr().Is4() {
|
|
||||||
if !addr.Addr().Is4() {
|
|
||||||
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
|
||||||
}
|
|
||||||
var sa unix.RawSockaddrInet4
|
|
||||||
sa.Family = unix.AF_INET
|
|
||||||
sa.Addr = addr.Addr().As4()
|
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
|
||||||
size := unix.SizeofSockaddrInet4
|
|
||||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
|
|
||||||
return uint32(size), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var sa unix.RawSockaddrInet6
|
|
||||||
sa.Family = unix.AF_INET6
|
|
||||||
sa.Addr = addr.Addr().As16()
|
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
|
||||||
size := unix.SizeofSockaddrInet6
|
|
||||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
|
|
||||||
return uint32(size), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
|
||||||
if len(buf) == 0 {
|
|
||||||
msg.Hdr.Control = nil
|
|
||||||
msg.Hdr.Controllen = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msg.Hdr.Control = &buf[0]
|
|
||||||
msg.Hdr.Controllen = uint64(len(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
|
||||||
h.Len = uint64(l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setIovecSlice(iov *iovec, b []byte) {
|
|
||||||
if len(b) == 0 {
|
|
||||||
iov.Base = nil
|
|
||||||
iov.Len = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
iov.Base = &b[0]
|
|
||||||
iov.Len = uint64(len(b))
|
|
||||||
}
|
|
||||||
@@ -65,16 +65,8 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
wait, err := ctrl.Start()
|
ctrl.Start()
|
||||||
if err != nil {
|
ctrl.ShutdownBlock()
|
||||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go ctrl.ShutdownBlock()
|
|
||||||
wait()
|
|
||||||
|
|
||||||
l.Info("Goodbye")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -3,9 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
_ "net/http/pprof"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -61,22 +58,10 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
|
||||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
wait, err := ctrl.Start()
|
ctrl.Start()
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go ctrl.ShutdownBlock()
|
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
wait()
|
ctrl.ShutdownBlock()
|
||||||
|
|
||||||
l.Info("Goodbye")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if mainHostInfo {
|
if mainHostInfo {
|
||||||
decision = tryRehandshake
|
decision = tryRehandshake
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.shouldSwapPrimary(hostinfo) {
|
if cm.shouldSwapPrimary(hostinfo) {
|
||||||
decision = swapPrimary
|
decision = swapPrimary
|
||||||
@@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||||
|
if crt == nil {
|
||||||
|
//my cert was reloaded away. We should definitely swap from this tunnel
|
||||||
|
return true
|
||||||
|
}
|
||||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||||
// settle down.
|
// settle down.
|
||||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||||
@@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
|||||||
cm.hostMap.Unlock()
|
cm.hostMap.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
// isInvalidCertificate decides if we should destroy a tunnel.
|
||||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
||||||
// check and return true.
|
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
||||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||||
remoteCert := hostinfo.GetCert()
|
remoteCert := hostinfo.GetCert()
|
||||||
if remoteCert == nil {
|
if remoteCert == nil {
|
||||||
return false
|
return false //don't tear down tunnels for handshakes in progress
|
||||||
}
|
}
|
||||||
|
|
||||||
caPool := cm.intf.pki.GetCAPool()
|
caPool := cm.intf.pki.GetCAPool()
|
||||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false //cert is still valid! yay!
|
||||||
}
|
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||||
|
|
||||||
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
|
||||||
// Block listed certificates should always be disconnected
|
// Block listed certificates should always be disconnected
|
||||||
return false
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
}
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
|
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||||
|
return true
|
||||||
|
} else if cm.intf.disconnectInvalid.Load() {
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
} else {
|
||||||
|
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||||
@@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
cs := cm.intf.pki.getCertState()
|
cs := cm.intf.pki.getCertState()
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
myCrt := cs.getCertificate(curCrt.Version())
|
curCrtVersion := curCrt.Version()
|
||||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
myCrt := cs.getCertificate(curCrtVersion)
|
||||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
if myCrt == nil {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("version", curCrtVersion).
|
||||||
|
WithField("reason", "local certificate removed").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
peerCrt := hostinfo.ConnectionState.peerCert
|
||||||
|
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||||
|
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||||
|
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("version", curCrtVersion).
|
||||||
|
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||||
|
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||||
|
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
WithField("reason", "local certificate is not current").
|
WithField("reason", "local certificate is not current").
|
||||||
Info("Re-handshaking with remote")
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if curCrtVersion < cs.initiatingVersion {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -446,10 +446,6 @@ func (d *dummyCert) PublicKey() []byte {
|
|||||||
return d.publicKey
|
return d.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
|
|
||||||
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dummyCert) Signature() []byte {
|
func (d *dummyCert) Signature() []byte {
|
||||||
return d.signature
|
return d.signature
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
|
const ReplayWindow = 1024
|
||||||
// 4092 should be sufficient for 5Gbps
|
|
||||||
const ReplayWindow = 8192
|
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
eKey *NebulaCipherState
|
eKey *NebulaCipherState
|
||||||
|
|||||||
56
control.go
56
control.go
@@ -2,11 +2,9 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -15,16 +13,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RunState int
|
|
||||||
|
|
||||||
const (
|
|
||||||
Stopped RunState = 0 // The control has yet to be started
|
|
||||||
Started RunState = 1 // The control has been started
|
|
||||||
Stopping RunState = 2 // The control is stopping
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrAlreadyStarted = errors.New("nebula is already started")
|
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
@@ -38,9 +26,6 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
stateLock sync.Mutex
|
|
||||||
state RunState
|
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -64,21 +49,10 @@ type ControlHostInfo struct {
|
|||||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call.
|
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||||
// The returned function can be used to wait for nebula to fully stop.
|
func (c *Control) Start() {
|
||||||
func (c *Control) Start() (func(), error) {
|
|
||||||
c.stateLock.Lock()
|
|
||||||
if c.state != Stopped {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
return nil, ErrAlreadyStarted
|
|
||||||
}
|
|
||||||
|
|
||||||
// Activate the interface
|
// Activate the interface
|
||||||
err := c.f.activate()
|
c.f.activate()
|
||||||
if err != nil {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||||
if c.sshStart != nil {
|
if c.sshStart != nil {
|
||||||
@@ -98,33 +72,15 @@ func (c *Control) Start() (func(), error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start reading packets.
|
// Start reading packets.
|
||||||
c.state = Started
|
c.f.run()
|
||||||
c.stateLock.Unlock()
|
|
||||||
return c.f.run(c.ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) State() RunState {
|
|
||||||
c.stateLock.Lock()
|
|
||||||
defer c.stateLock.Unlock()
|
|
||||||
return c.state
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Context() context.Context {
|
func (c *Control) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
||||||
func (c *Control) Stop() {
|
func (c *Control) Stop() {
|
||||||
c.stateLock.Lock()
|
|
||||||
if c.state != Started {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
// We are stopping or stopped already
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.state = Stopping
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
|
|
||||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||||
// being created while we're shutting them all down.
|
// being created while we're shutting them all down.
|
||||||
c.cancel()
|
c.cancel()
|
||||||
@@ -133,7 +89,7 @@ func (c *Control) Stop() {
|
|||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.WithError(err).Error("Close interface failed")
|
||||||
}
|
}
|
||||||
c.state = Stopped
|
c.l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
|
|||||||
@@ -29,8 +29,6 @@ type m = map[string]any
|
|||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
var vpnNetworks []netip.Prefix
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
@@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
budpIp[3] = 239
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
|
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
|
var vpnNetworks []netip.Prefix
|
||||||
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnNetworks) == 0 {
|
||||||
|
panic("no vpn networks")
|
||||||
|
}
|
||||||
|
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
@@ -129,6 +146,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
return control, vpnNetworks, udpAddr, c
|
return control, vpnNetworks, udpAddr, c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newServer creates a nebula instance with fewer assumptions
|
||||||
|
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
|
vpnNetworks := certs[len(certs)-1].Networks()
|
||||||
|
|
||||||
|
var udpAddr netip.AddrPort
|
||||||
|
if vpnNetworks[0].Addr().Is4() {
|
||||||
|
budpIp := vpnNetworks[0].Addr().As4()
|
||||||
|
budpIp[1] -= 128
|
||||||
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||||
|
} else {
|
||||||
|
budpIp := vpnNetworks[0].Addr().As16()
|
||||||
|
// beef for funsies
|
||||||
|
budpIp[2] = 190
|
||||||
|
budpIp[3] = 239
|
||||||
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
|
}
|
||||||
|
|
||||||
|
caStr := ""
|
||||||
|
for _, ca := range caCrt {
|
||||||
|
x, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr += string(x)
|
||||||
|
}
|
||||||
|
certStr := ""
|
||||||
|
for _, c := range certs {
|
||||||
|
x, err := c.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
certStr += string(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": certStr,
|
||||||
|
"key": string(key),
|
||||||
|
},
|
||||||
|
//"tun": m{"disabled": true},
|
||||||
|
"firewall": m{
|
||||||
|
"outbound": []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}},
|
||||||
|
"inbound": []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
//"handshakes": m{
|
||||||
|
// "try_interval": "1s",
|
||||||
|
//},
|
||||||
|
"listen": m{
|
||||||
|
"host": udpAddr.Addr().String(),
|
||||||
|
"port": udpAddr.Port(),
|
||||||
|
},
|
||||||
|
"logging": m{
|
||||||
|
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||||
|
"level": l.Level.String(),
|
||||||
|
},
|
||||||
|
"timers": m{
|
||||||
|
"pending_deletion_interval": 2,
|
||||||
|
"connection_alive_interval": 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if overrides != nil {
|
||||||
|
final := m{}
|
||||||
|
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
mc = final
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := config.NewC(l)
|
||||||
|
cStr := string(cb)
|
||||||
|
c.LoadString(cStr)
|
||||||
|
|
||||||
|
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return control, vpnNetworks, udpAddr, c
|
||||||
|
}
|
||||||
|
|
||||||
type doneCb func()
|
type doneCb func()
|
||||||
|
|
||||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||||
|
|||||||
@@ -4,12 +4,16 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDropInactiveTunnels(t *testing.T) {
|
func TestDropInactiveTunnels(t *testing.T) {
|
||||||
@@ -55,3 +59,309 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCertUpgrade(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
caB, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
ca2B, err := ca2.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": string(myCert2Pem),
|
||||||
|
"key": string(myPrivKey),
|
||||||
|
},
|
||||||
|
//"tun": m{"disabled": true},
|
||||||
|
"firewall": myC.Settings["firewall"],
|
||||||
|
//"handshakes": m{
|
||||||
|
// "try_interval": "1s",
|
||||||
|
//},
|
||||||
|
"listen": myC.Settings["listen"],
|
||||||
|
"logging": myC.Settings["logging"],
|
||||||
|
"timers": myC.Settings["timers"],
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("reload new v2-only config")
|
||||||
|
err = myC.ReloadConfigString(string(cb))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
r.Log("yay, spin until their sees it")
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
r.Logf("version %d", version)
|
||||||
|
if version == cert.Version2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*10 {
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertDowngrade(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
caB, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
ca2B, err := ca2.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
//r.Log("yay")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": string(myCertPem),
|
||||||
|
"key": string(myPrivKey),
|
||||||
|
},
|
||||||
|
"firewall": myC.Settings["firewall"],
|
||||||
|
"listen": myC.Settings["listen"],
|
||||||
|
"logging": myC.Settings["logging"],
|
||||||
|
"timers": myC.Settings["timers"],
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("reload new v1-only config")
|
||||||
|
err = myC.ReloadConfigString(string(cb))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
r.Log("yay, spin until their sees it")
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil || c2 == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
theirVersion := c2.Cert.Version()
|
||||||
|
r.Logf("version %d,%d", version, theirVersion)
|
||||||
|
if version == cert.Version1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*5 {
|
||||||
|
r.Log("wtf")
|
||||||
|
}
|
||||||
|
if since > time.Second*10 {
|
||||||
|
r.Log("wtf")
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertMismatchCorrection(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
//r.Log("yay")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil || c2 == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
theirVersion := c2.Cert.Version()
|
||||||
|
r.Logf("version %d,%d", version, theirVersion)
|
||||||
|
if version == theirVersion {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*5 {
|
||||||
|
r.Log("wtf")
|
||||||
|
}
|
||||||
|
if since > time.Second*10 {
|
||||||
|
r.Log("wtf")
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossStackRelaysWork(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||||
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||||
|
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
|
//myVpnV4 := myVpnIpNet[0]
|
||||||
|
myVpnV6 := myVpnIpNet[1]
|
||||||
|
relayVpnV4 := relayVpnIpNet[0]
|
||||||
|
relayVpnV6 := relayVpnIpNet[1]
|
||||||
|
theirVpnV6 := theirVpnIpNet[0]
|
||||||
|
|
||||||
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
// Build a router so we don't have to reason who gets which packet
|
||||||
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
r.Log("Assert the tunnel works")
|
||||||
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
|
t.Log("reply?")
|
||||||
|
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||||
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
|
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
|
//t.Log("finish up")
|
||||||
|
//myControl.Stop()
|
||||||
|
//theirControl.Stop()
|
||||||
|
//relayControl.Stop()
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,15 +23,19 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we're connecting to a v6 address we must use a v2 cert
|
|
||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
v := cs.initiatingVersion
|
v := cs.initiatingVersion
|
||||||
|
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||||
|
v = hh.initiatingVersionOverride
|
||||||
|
} else if v < cert.Version2 {
|
||||||
|
// If we're connecting to a v6 address we must use a v2 cert
|
||||||
for _, a := range hh.hostinfo.vpnAddrs {
|
for _, a := range hh.hostinfo.vpnAddrs {
|
||||||
if a.Is6() {
|
if a.Is6() {
|
||||||
v = cert.Version2
|
v = cert.Version2
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
crt := cs.getCertificate(v)
|
crt := cs.getCertificate(v)
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
@@ -163,16 +167,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
if rc == nil {
|
if myCertOtherVersion == nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
f.l.WithError(err).WithFields(m{
|
||||||
Info("Unable to handshake with host due to missing certificate version")
|
"udpAddr": addr,
|
||||||
return
|
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"cert": remoteCert,
|
||||||
|
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
// Record the certificate we are actually using
|
// Record the certificate we are actually using
|
||||||
ci.myCert = rc
|
ci.myCert = myCertOtherVersion
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ type HandshakeHostInfo struct {
|
|||||||
|
|
||||||
startTime time.Time // Time that we first started trying with this handshake
|
startTime time.Time // Time that we first started trying with this handshake
|
||||||
ready bool // Is the handshake ready
|
ready bool // Is the handshake ready
|
||||||
|
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||||
counter int64 // How many attempts have we made so far
|
counter int64 // How many attempts have we made so far
|
||||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||||
@@ -299,6 +300,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
InitiatorRelayIndex: idx,
|
InitiatorRelayIndex: idx,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relayFrom := hm.f.myVpnAddrs[0]
|
||||||
|
|
||||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
if !hm.f.myVpnAddrs[0].Is4() {
|
if !hm.f.myVpnAddrs[0].Is4() {
|
||||||
@@ -316,7 +319,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
b = vpnIp.As4()
|
b = vpnIp.As4()
|
||||||
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
case cert.Version2:
|
case cert.Version2:
|
||||||
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
if vpnIp.Is4() {
|
||||||
|
relayFrom = hm.f.myVpnAddrs[0]
|
||||||
|
} else {
|
||||||
|
//todo do this smarter
|
||||||
|
relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
|
||||||
|
}
|
||||||
|
m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||||
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||||
default:
|
default:
|
||||||
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||||
@@ -331,7 +340,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
} else {
|
} else {
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": hm.f.myVpnAddrs[0],
|
"relayFrom": relayFrom,
|
||||||
"relayTo": vpnIp,
|
"relayTo": vpnIp,
|
||||||
"initiatorRelayIndex": idx,
|
"initiatorRelayIndex": idx,
|
||||||
"relay": relay}).
|
"relay": relay}).
|
||||||
@@ -357,6 +366,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
InitiatorRelayIndex: existingRelay.LocalIndex,
|
InitiatorRelayIndex: existingRelay.LocalIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relayFrom := hm.f.myVpnAddrs[0]
|
||||||
|
|
||||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
if !hm.f.myVpnAddrs[0].Is4() {
|
if !hm.f.myVpnAddrs[0].Is4() {
|
||||||
@@ -374,7 +385,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
b = vpnIp.As4()
|
b = vpnIp.As4()
|
||||||
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
case cert.Version2:
|
case cert.Version2:
|
||||||
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
if vpnIp.Is4() {
|
||||||
|
relayFrom = hm.f.myVpnAddrs[0]
|
||||||
|
} else {
|
||||||
|
//todo do this smarter
|
||||||
|
relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||||
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||||
default:
|
default:
|
||||||
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||||
@@ -389,7 +407,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": hm.f.myVpnAddrs[0],
|
"relayFrom": relayFrom,
|
||||||
"relayTo": vpnIp,
|
"relayTo": vpnIp,
|
||||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
"initiatorRelayIndex": existingRelay.LocalIndex,
|
||||||
"relay": relay}).
|
"relay": relay}).
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -521,6 +522,7 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
|
|||||||
return nil, nil, errors.New("unable to find host")
|
return nil, nil, errors.New("unable to find host")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastH := h
|
||||||
for h != nil {
|
for h != nil {
|
||||||
for _, targetIp := range targetIps {
|
for _, targetIp := range targetIps {
|
||||||
r, ok := h.relayState.QueryRelayForByIp(targetIp)
|
r, ok := h.relayState.QueryRelayForByIp(targetIp)
|
||||||
@@ -528,10 +530,12 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
|
|||||||
return h, r, nil
|
return h, r, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
lastH = h
|
||||||
h = h.next
|
h = h.next
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, errors.New("unable to find host with relay")
|
//todo no merge
|
||||||
|
return nil, nil, fmt.Errorf("unable to find host with relay: %v", lastH)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
|
func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
|
||||||
|
|||||||
156
interface.go
156
interface.go
@@ -6,8 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,17 +87,12 @@ type Interface struct {
|
|||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
pktPool *packet.Pool
|
|
||||||
inbound chan *packet.Packet
|
|
||||||
outbound chan *packet.Packet
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
@@ -200,15 +194,9 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||||
},
|
},
|
||||||
|
|
||||||
//TODO: configurable size
|
|
||||||
inbound: make(chan *packet.Packet, 2048),
|
|
||||||
outbound: make(chan *packet.Packet, 2048),
|
|
||||||
|
|
||||||
l: c.l,
|
l: c.l,
|
||||||
}
|
}
|
||||||
|
|
||||||
ifce.pktPool = packet.GetPool()
|
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||||
@@ -221,7 +209,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
// activate creates the interface on the host. After the interface is created, any
|
// activate creates the interface on the host. After the interface is created, any
|
||||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||||
// the interface isn't going to process anything until run() is called.
|
// the interface isn't going to process anything until run() is called.
|
||||||
func (f *Interface) activate() error {
|
func (f *Interface) activate() {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
@@ -242,46 +230,33 @@ func (f *Interface) activate() error {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
f.l.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = f.inside.Activate(); err != nil {
|
if err := f.inside.Activate(); err != nil {
|
||||||
f.inside.Close()
|
f.inside.Close()
|
||||||
return err
|
f.l.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) run(c context.Context) (func(), error) {
|
func (f *Interface) run() {
|
||||||
|
// Launch n queues to read packets from udp
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
// read packets from udp and queue to f.inbound
|
|
||||||
f.wg.Add(1)
|
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
|
|
||||||
// Launch n queues to read packets from inside tun dev and queue to f.outbound
|
|
||||||
//todo this never stops f.wg.Add(1)
|
|
||||||
go f.listenIn(f.readers[i], i)
|
|
||||||
|
|
||||||
// Launch n workers to process traffic from f.inbound and smash it onto the inside of the tun
|
|
||||||
f.wg.Add(1)
|
|
||||||
go f.workerIn(i, c)
|
|
||||||
f.wg.Add(1)
|
|
||||||
go f.workerIn(i, c)
|
|
||||||
|
|
||||||
// read from f.outbound and write to UDP (outside the tun)
|
|
||||||
f.wg.Add(1)
|
|
||||||
go f.workerOut(i, c)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.wg.Wait, nil
|
// Launch n queues to read packets from tun dev
|
||||||
|
for i := 0; i < f.routines; i++ {
|
||||||
|
go f.listenIn(f.readers[i], i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
@@ -289,90 +264,41 @@ func (f *Interface) listenOut(i int) {
|
|||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
err := li.ListenOut(f.pktPool.Get, f.inbound)
|
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
if err != nil && !f.closed.Load() {
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
|
plaintext := make([]byte, udp.MTU)
|
||||||
//TODO: Trigger Control to close
|
h := &header.H{}
|
||||||
}
|
fwPacket := &firewall.Packet{}
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
f.l.Debugf("underlay reader %v is done", i)
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.wg.Done()
|
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
|
packet := make([]byte, mtu)
|
||||||
|
out := make([]byte, mtu)
|
||||||
|
fwPacket := &firewall.Packet{}
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
p := f.pktPool.Get()
|
n, err := reader.Read(packet)
|
||||||
n, err := reader.Read(p.Payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !f.closed.Load() {
|
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet, closing")
|
|
||||||
//TODO: Trigger Control to close
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
p.Payload = (p.Payload)[:n]
|
|
||||||
//TODO: nonblocking channel write
|
|
||||||
f.outbound <- p
|
|
||||||
//select {
|
|
||||||
//case f.outbound <- p:
|
|
||||||
//default:
|
|
||||||
// f.l.Error("Dropped packet from outbound channel")
|
|
||||||
//}
|
|
||||||
}
|
|
||||||
|
|
||||||
f.l.Debugf("overlay reader %v is done", i)
|
|
||||||
f.wg.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) workerIn(i int, ctx context.Context) {
|
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
|
||||||
fwPacket2 := &firewall.Packet{}
|
|
||||||
nb2 := make([]byte, 12, 12)
|
|
||||||
result2 := make([]byte, mtu)
|
|
||||||
h := &header.H{}
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case p := <-f.inbound:
|
|
||||||
if p.SegSize > 0 && p.SegSize < len(p.Payload) {
|
|
||||||
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
|
|
||||||
end := offset + p.SegSize
|
|
||||||
if end > len(p.Payload) {
|
|
||||||
end = len(p.Payload)
|
|
||||||
}
|
|
||||||
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload[offset:end], h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
|
||||||
}
|
|
||||||
|
|
||||||
f.pktPool.Put(p)
|
|
||||||
case <-ctx.Done():
|
|
||||||
f.wg.Done()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) workerOut(i int, ctx context.Context) {
|
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||||
fwPacket1 := &firewall.Packet{}
|
os.Exit(2)
|
||||||
nb1 := make([]byte, 12, 12)
|
|
||||||
result1 := make([]byte, mtu)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case data := <-f.outbound:
|
|
||||||
f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
|
|
||||||
f.pktPool.Put(data)
|
|
||||||
case <-ctx.Done():
|
|
||||||
f.wg.Done()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -525,7 +451,6 @@ func (f *Interface) GetCertState() *CertState {
|
|||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
// Release the udp readers
|
|
||||||
for _, u := range f.writers {
|
for _, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -533,13 +458,6 @@ func (f *Interface) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the tun readers
|
// Release the tun device
|
||||||
for _, u := range f.readers {
|
return f.inside.Close()
|
||||||
err := u.Close()
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Error("Error while closing tun device")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
18
main.go
18
main.go
@@ -284,14 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
f: ifce,
|
ifce,
|
||||||
l: l,
|
l,
|
||||||
ctx: ctx,
|
ctx,
|
||||||
cancel: cancel,
|
cancel,
|
||||||
sshStart: sshStart,
|
sshStart,
|
||||||
statsStart: statsStart,
|
statsStart,
|
||||||
dnsStart: dnsStart,
|
dnsStart,
|
||||||
lighthouseStart: lightHouse.StartUpdateWorker,
|
lightHouse.StartUpdateWorker,
|
||||||
connectionManagerStart: connManager.Start,
|
connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//f.l.Error("in packet ", h)
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
@@ -245,7 +245,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
||||||
Info("Host roamed to new udp ip/port.")
|
Info("Host roamed to new udp ip/port.")
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
@@ -471,7 +470,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -71,13 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
|||||||
|
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
|
||||||
pLen := 128
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
pLen = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -553,3 +554,13 @@ func (t *tun) Name() string {
|
|||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||||
|
pLen := 128
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
pLen = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
@@ -20,18 +22,12 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
||||||
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
||||||
FIODGNAME = 0x80106678
|
FIODGNAME = 0x80106678
|
||||||
TUNSIFMODE = 0x8004745e
|
|
||||||
TUNSIFHEAD = 0x80047460
|
|
||||||
OSIOCAIFADDR_IN6 = 0x8088691b
|
|
||||||
IN6_IFF_NODAD = 0x0020
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type fiodgnameArg struct {
|
type fiodgnameArg struct {
|
||||||
@@ -41,159 +37,43 @@ type fiodgnameArg struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ifreqRename struct {
|
type ifreqRename struct {
|
||||||
Name [unix.IFNAMSIZ]byte
|
Name [16]byte
|
||||||
Data uintptr
|
Data uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifreqDestroy struct {
|
type ifreqDestroy struct {
|
||||||
Name [unix.IFNAMSIZ]byte
|
Name [16]byte
|
||||||
pad [16]byte
|
pad [16]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifReq struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Flags uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqMTU struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
MTU int32
|
|
||||||
}
|
|
||||||
|
|
||||||
type addrLifetime struct {
|
|
||||||
Expire uint64
|
|
||||||
Preferred uint64
|
|
||||||
Vltime uint32
|
|
||||||
Pltime uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias4 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
DstAddr unix.RawSockaddrInet4
|
|
||||||
MaskAddr unix.RawSockaddrInet4
|
|
||||||
VHid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias6 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet6
|
|
||||||
DstAddr unix.RawSockaddrInet6
|
|
||||||
PrefixMask unix.RawSockaddrInet6
|
|
||||||
Flags uint32
|
|
||||||
Lifetime addrLifetime
|
|
||||||
VHid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
linkAddr *netroute.LinkAddr
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
devFd int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
io.ReadWriteCloser
|
||||||
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
|
||||||
if t.devFd < 0 {
|
|
||||||
return -1, syscall.EINVAL
|
|
||||||
}
|
|
||||||
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
head := make([]byte, 4)
|
|
||||||
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&to[0], uint64(len(to))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if errno != 0 {
|
|
||||||
err = syscall.Errno(errno)
|
|
||||||
} else {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
// fix bytes read number to exclude header
|
|
||||||
bytesRead := int(n)
|
|
||||||
if bytesRead < 0 {
|
|
||||||
return bytesRead, err
|
|
||||||
} else if bytesRead < 4 {
|
|
||||||
return 0, err
|
|
||||||
} else {
|
|
||||||
return bytesRead - 4, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
|
||||||
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
|
||||||
if t.devFd < 0 {
|
|
||||||
return -1, syscall.EINVAL
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(from) <= 1 {
|
|
||||||
return 0, syscall.EIO
|
|
||||||
}
|
|
||||||
ipVer := from[0] >> 4
|
|
||||||
var head []byte
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
if ipVer == 4 {
|
|
||||||
head = []byte{0, 0, 0, syscall.AF_INET}
|
|
||||||
} else if ipVer == 6 {
|
|
||||||
head = []byte{0, 0, 0, syscall.AF_INET6}
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
|
||||||
}
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&from[0], uint64(len(from))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if errno != 0 {
|
|
||||||
err = syscall.Errno(errno)
|
|
||||||
} else {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(n) - 4, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *tun) Close() error {
|
||||||
if t.devFd >= 0 {
|
if t.ReadWriteCloser != nil {
|
||||||
err := syscall.Close(t.devFd)
|
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
t.l.WithError(err).Error("Error closing device")
|
|
||||||
}
|
}
|
||||||
t.devFd = -1
|
|
||||||
|
|
||||||
c := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
|
||||||
defer close(c)
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
if err == nil {
|
|
||||||
defer syscall.Close(s)
|
|
||||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
|
||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).Error("Error destroying tunnel")
|
return err
|
||||||
}
|
}
|
||||||
}()
|
defer syscall.Close(s)
|
||||||
|
|
||||||
// wait up to 1 second so we start blocking at the ioctl
|
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||||
select {
|
|
||||||
case <-c:
|
// Destroy the interface
|
||||||
case <-time.After(1 * time.Second):
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||||
}
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -205,37 +85,32 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
|
|||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open existing tun device
|
// Try to open existing tun device
|
||||||
var fd int
|
var file *os.File
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName != "" {
|
if deviceName != "" {
|
||||||
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
}
|
}
|
||||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||||
// If the device doesn't already exist, request a new one and rename it
|
// If the device doesn't already exist, request a new one and rename it
|
||||||
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the name of the interface
|
rawConn, err := file.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SyscallConn: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var name [16]byte
|
var name [16]byte
|
||||||
|
var ctrlErr error
|
||||||
|
rawConn.Control(func(fd uintptr) {
|
||||||
|
// Read the name of the interface
|
||||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||||
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||||
|
})
|
||||||
if ctrlErr == nil {
|
|
||||||
// set broadcast mode and multicast
|
|
||||||
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
|
||||||
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctrlErr == nil {
|
|
||||||
// turn on link-layer mode, to support ipv6
|
|
||||||
ifhead := uint32(1)
|
|
||||||
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctrlErr != nil {
|
if ctrlErr != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -247,7 +122,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
|
|
||||||
// If the name doesn't match the desired interface name, rename it now
|
// If the name doesn't match the desired interface name, rename it now
|
||||||
if ifName != deviceName {
|
if ifName != deviceName {
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
s, err := syscall.Socket(
|
||||||
|
syscall.AF_INET,
|
||||||
|
syscall.SOCK_DGRAM,
|
||||||
|
syscall.IPPROTO_IP,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -270,11 +149,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
|
ReadWriteCloser: file,
|
||||||
Device: deviceName,
|
Device: deviceName,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
devFd: fd,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -293,111 +172,38 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
if cidr.Addr().Is4() {
|
var err error
|
||||||
ifr := ifreqAlias4{
|
// TODO use syscalls instead of exec.Command
|
||||||
Name: t.deviceBytes(),
|
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||||
Addr: unix.RawSockaddrInet4{
|
t.l.Debug("command: ", cmd.String())
|
||||||
Len: unix.SizeofSockaddrInet4,
|
if err = cmd.Run(); err != nil {
|
||||||
Family: unix.AF_INET,
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
},
|
|
||||||
DstAddr: unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: getBroadcast(cidr).As4(),
|
|
||||||
},
|
|
||||||
MaskAddr: unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: prefixToMask(cidr).As4(),
|
|
||||||
},
|
|
||||||
VHid: 0,
|
|
||||||
}
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
|
||||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cidr.Addr().Is6() {
|
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
|
||||||
ifr := ifreqAlias6{
|
t.l.Debug("command: ", cmd.String())
|
||||||
Name: t.deviceBytes(),
|
if err = cmd.Run(); err != nil {
|
||||||
Addr: unix.RawSockaddrInet6{
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: cidr.Addr().As16(),
|
|
||||||
},
|
|
||||||
PrefixMask: unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: prefixToMask(cidr).As16(),
|
|
||||||
},
|
|
||||||
Lifetime: addrLifetime{
|
|
||||||
Expire: 0,
|
|
||||||
Preferred: 0,
|
|
||||||
Vltime: 0xffffffff,
|
|
||||||
Pltime: 0xffffffff,
|
|
||||||
},
|
|
||||||
Flags: IN6_IFF_NODAD,
|
|
||||||
}
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("unknown address type %v", cidr)
|
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsafe path routes
|
||||||
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) Activate() error {
|
||||||
// Setup our default MTU
|
|
||||||
err := t.setMTU()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
linkAddr, err := getLinkAddr(t.Device)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if linkAddr == nil {
|
|
||||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
|
||||||
}
|
|
||||||
t.linkAddr = linkAddr
|
|
||||||
|
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
err := t.addIp(t.vpnNetworks[i])
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return t.addRoutes(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) setMTU() error {
|
|
||||||
// Set the MTU on the device
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
|
||||||
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
@@ -462,16 +268,15 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := addRoute(r.Cidr, t.linkAddr)
|
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||||
if err != nil {
|
t.l.Debug("command: ", cmd.String())
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
if err := cmd.Run(); err != nil {
|
||||||
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Added route")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -484,8 +289,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := delRoute(r.Cidr, t.linkAddr)
|
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||||
if err != nil {
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
@@ -500,144 +306,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func flipBytes(b []byte) []byte {
|
|
||||||
for i := 0; i < len(b); i++ {
|
|
||||||
b[i] ^= 0xFF
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
func orBytes(a []byte, b []byte) []byte {
|
|
||||||
ret := make([]byte, len(a))
|
|
||||||
for i := 0; i < len(a); i++ {
|
|
||||||
ret[i] = a[i] | b[i]
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
|
||||||
broadcast, _ := netip.AddrFromSlice(
|
|
||||||
orBytes(
|
|
||||||
cidr.Addr().AsSlice(),
|
|
||||||
flipBytes(prefixToMask(cidr).AsSlice()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return broadcast
|
|
||||||
}
|
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := &netroute.RouteMessage{
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Type: unix.RTM_ADD,
|
|
||||||
Flags: unix.RTF_UP,
|
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
// Try to do a change
|
|
||||||
route.Type = unix.RTM_CHANGE
|
|
||||||
data, err = route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
fmt.Println("DOING CHANGE")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := netroute.RouteMessage{
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Type: unix.RTM_DELETE,
|
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getLinkAddr Gets the link address for the interface of the given name
|
|
||||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
|
||||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range msgs {
|
|
||||||
switch m := m.(type) {
|
|
||||||
case *netroute.InterfaceMessage:
|
|
||||||
if m.Name == name {
|
|
||||||
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
|
||||||
if ok {
|
|
||||||
return sa, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
package packet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
const Size = 0xffff
|
|
||||||
|
|
||||||
type Packet struct {
|
|
||||||
Payload []byte
|
|
||||||
Control []byte
|
|
||||||
SegSize int
|
|
||||||
Addr netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() *Packet {
|
|
||||||
return &Packet{
|
|
||||||
Payload: make([]byte, Size),
|
|
||||||
Control: make([]byte, unix.CmsgSpace(2)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Pool struct {
|
|
||||||
pool sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
var bigPool = &Pool{
|
|
||||||
pool: sync.Pool{New: func() any { return New() }},
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPool() *Pool {
|
|
||||||
return bigPool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) Get() *Packet {
|
|
||||||
return p.pool.Get().(*Packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) Put(x *Packet) {
|
|
||||||
x.Payload = x.Payload[:Size]
|
|
||||||
p.pool.Put(x)
|
|
||||||
}
|
|
||||||
39
pki.go
39
pki.go
@@ -100,41 +100,36 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
currentState := p.cs.Load()
|
currentState := p.cs.Load()
|
||||||
if newState.v1Cert != nil {
|
if newState.v1Cert != nil {
|
||||||
if currentState.v1Cert == nil {
|
if currentState.v1Cert == nil {
|
||||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||||
}
|
} else {
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
// did IP in cert change? if so, don't set
|
||||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Networks in new cert was different from old",
|
"Networks in new cert was different from old",
|
||||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Curve in new cert was different from old",
|
"Curve in new v1 cert was different from old",
|
||||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else if currentState.v1Cert != nil {
|
|
||||||
//TODO: CERT-V2 we should be able to tear this down
|
|
||||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if newState.v2Cert != nil {
|
if newState.v2Cert != nil {
|
||||||
if currentState.v2Cert == nil {
|
if currentState.v2Cert == nil {
|
||||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
//adding certs is fine, actually
|
||||||
}
|
} else {
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
// did IP in cert change? if so, don't set
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Networks in new cert was different from old",
|
"Networks in new cert was different from old",
|
||||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -142,13 +137,25 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Curve in new cert was different from old",
|
"Curve in new cert was different from old",
|
||||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} else if currentState.v2Cert != nil {
|
} else if currentState.v2Cert != nil {
|
||||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
||||||
|
if newState.v1Cert == nil {
|
||||||
|
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
||||||
|
}
|
||||||
|
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
||||||
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
||||||
|
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cipher cant be hot swapped so just leave it at what it was before
|
// Cipher cant be hot swapped so just leave it at what it was before
|
||||||
|
|||||||
@@ -190,6 +190,7 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
InitiatorRelayIndex: peerRelay.RemoteIndex,
|
InitiatorRelayIndex: peerRelay.RemoteIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relayFrom := h.vpnAddrs[0]
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
peer := peerHostInfo.vpnAddrs[0]
|
peer := peerHostInfo.vpnAddrs[0]
|
||||||
if !peer.Is4() {
|
if !peer.Is4() {
|
||||||
@@ -207,7 +208,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
b = targetAddr.As4()
|
b = targetAddr.As4()
|
||||||
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
} else {
|
} else {
|
||||||
resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
|
if targetAddr.Is4() {
|
||||||
|
relayFrom = h.vpnAddrs[0]
|
||||||
|
} else {
|
||||||
|
//todo do this smarter
|
||||||
|
relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
|
||||||
|
}
|
||||||
|
resp.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||||
resp.RelayToAddr = target
|
resp.RelayToAddr = target
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,7 +367,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
Type: NebulaControl_CreateRelayRequest,
|
Type: NebulaControl_CreateRelayRequest,
|
||||||
InitiatorRelayIndex: index,
|
InitiatorRelayIndex: index,
|
||||||
}
|
}
|
||||||
|
relayFrom := h.vpnAddrs[0]
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if !h.vpnAddrs[0].Is4() {
|
if !h.vpnAddrs[0].Is4() {
|
||||||
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
|
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
|
||||||
@@ -377,7 +384,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
b = target.As4()
|
b = target.As4()
|
||||||
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
} else {
|
} else {
|
||||||
req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
|
if target.Is4() {
|
||||||
|
relayFrom = h.vpnAddrs[0]
|
||||||
|
} else {
|
||||||
|
//todo do this smarter
|
||||||
|
relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
|
||||||
|
}
|
||||||
|
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||||
req.RelayToAddr = netAddrToProtoAddr(target)
|
req.RelayToAddr = netAddrToProtoAddr(target)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,7 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": h.vpnAddrs[0],
|
"relayFrom": relayFrom,
|
||||||
"relayTo": target,
|
"relayTo": target,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex": req.ResponderRelayIndex,
|
||||||
|
|||||||
@@ -9,13 +9,10 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
@@ -46,19 +43,8 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config *config.C) (*Service, error) {
|
func New(control *nebula.Control) (*Service, error) {
|
||||||
logger := logrus.New()
|
control.Start()
|
||||||
logger.Out = os.Stdout
|
|
||||||
|
|
||||||
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
wait, err := control.Start()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := control.Context()
|
ctx := control.Context()
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
@@ -155,12 +141,6 @@ func New(config *config.C) (*Service, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add the nebula wait function to the group
|
|
||||||
eg.Go(func() error {
|
|
||||||
wait()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
10
udp/conn.go
10
udp/conn.go
@@ -4,19 +4,19 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const MTU = 9001
|
const MTU = 9001
|
||||||
|
|
||||||
type EncReader func(*packet.Packet)
|
type EncReader func(
|
||||||
|
addr netip.AddrPort,
|
||||||
type PacketBufferGetter func() *packet.Packet
|
payload []byte,
|
||||||
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Close() error
|
Close() error
|
||||||
|
|||||||
@@ -71,14 +71,15 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader) error {
|
func (u *GenericConn) ListenOut(r EncReader) {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
|
|||||||
@@ -9,26 +9,27 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
|
|
||||||
|
|
||||||
type StdConn struct {
|
type StdConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
batch int
|
batch int
|
||||||
enableGRO bool
|
}
|
||||||
enableGSO bool
|
|
||||||
//gso gsoState
|
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||||
|
ip4 := ip.To4()
|
||||||
|
if ip4 != nil {
|
||||||
|
return ip4, true
|
||||||
|
}
|
||||||
|
return ip, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
@@ -54,11 +55,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set a read timeout
|
|
||||||
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var sa unix.Sockaddr
|
var sa unix.Sockaddr
|
||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
sa4 := &unix.SockaddrInet4{Port: port}
|
sa4 := &unix.SockaddrInet4{Port: port}
|
||||||
@@ -122,10 +118,10 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
|
|
||||||
msgs, packets, names := u.PrepareRawMessages(u.batch, pg)
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
@@ -134,61 +130,22 @@ func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
|
|||||||
for {
|
for {
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
out := packets[i]
|
|
||||||
out.Payload = out.Payload[:msgs[i].Len]
|
|
||||||
|
|
||||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||||
} else {
|
} else {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
}
|
}
|
||||||
out.Addr = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||||
ctrlLen := getRawMessageControlLen(&msgs[i])
|
|
||||||
if ctrlLen > 0 {
|
|
||||||
packets[i].SegSize = parseGROControl(packets[i].Control[:ctrlLen])
|
|
||||||
} else {
|
|
||||||
packets[i].SegSize = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
pc <- out
|
|
||||||
|
|
||||||
//rotate this packet out so we don't overwrite it
|
|
||||||
packets[i] = pg()
|
|
||||||
msgs[i].Hdr.Iov.Base = &packets[i].Payload[0]
|
|
||||||
if u.enableGRO {
|
|
||||||
msgs[i].Hdr.Control = &packets[i].Control[0]
|
|
||||||
msgs[i].Hdr.Controllen = uint64(cap(packets[i].Control))
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGROControl(control []byte) int {
|
|
||||||
if len(control) == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
cmsgs, err := unix.ParseSocketControlMessage(control)
|
|
||||||
if err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cmsgs {
|
|
||||||
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
|
||||||
segSize := int(binary.LittleEndian.Uint16(c.Data[:2]))
|
|
||||||
return segSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||||
for {
|
for {
|
||||||
n, _, err := unix.Syscall6(
|
n, _, err := unix.Syscall6(
|
||||||
@@ -202,9 +159,6 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
if err == unix.EAGAIN || err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,9 +180,6 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
if err == unix.EAGAIN || err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,7 +221,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
|||||||
|
|
||||||
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
||||||
if !ip.Addr().Is4() {
|
if !ip.Addr().Is4() {
|
||||||
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
return ErrInvalidIPv6RemoteForSocket
|
||||||
}
|
}
|
||||||
|
|
||||||
var rsa unix.RawSockaddrInet4
|
var rsa unix.RawSockaddrInet4
|
||||||
@@ -343,28 +294,6 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
u.configureGRO(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) configureGRO(enable bool) {
|
|
||||||
if enable == u.enableGRO {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if enable {
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
u.enableGRO = true
|
|
||||||
u.l.Info("UDP GRO enabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
|
|
||||||
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
|
||||||
}
|
|
||||||
u.enableGRO = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,39 +33,17 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
if len(buf) == 0 {
|
|
||||||
msg.Hdr.Control = nil
|
|
||||||
msg.Hdr.Controllen = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msg.Hdr.Control = &buf[0]
|
|
||||||
msg.Hdr.Controllen = uint64(len(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRawMessageControlLen(msg *rawMessage) int {
|
|
||||||
return int(msg.Hdr.Controllen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
|
||||||
h.Len = uint64(l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage, []*packet.Packet, [][]byte) {
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
packets := make([]*packet.Packet, n)
|
|
||||||
for i := range packets {
|
|
||||||
packets[i] = pg()
|
|
||||||
}
|
|
||||||
//todo?
|
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
|
buffers[i] = make([]byte, MTU)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
{Base: &packets[i].Payload[0], Len: uint64(packet.Size)},
|
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
@@ -74,14 +51,7 @@ func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage
|
|||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
if u.enableGRO {
|
|
||||||
msgs[i].Hdr.Control = &packets[i].Control[0]
|
|
||||||
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
|
|
||||||
} else {
|
|
||||||
msgs[i].Hdr.Control = nil
|
|
||||||
msgs[i].Hdr.Controllen = 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, packets, names
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader) error {
|
func (u *RIOConn) ListenOut(r EncReader) {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|||||||
Reference in New Issue
Block a user