mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
3 Commits
channels-s
...
remove-old
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c745e8cfe | ||
|
|
87a4ec7d90 | ||
|
|
47d4055e10 |
2
.github/ISSUE_TEMPLATE/config.yml
vendored
2
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -17,5 +17,5 @@ contact_links:
|
|||||||
about: 'The documentation is the best place to start if you are new to Nebula.'
|
about: 'The documentation is the best place to start if you are new to Nebula.'
|
||||||
|
|
||||||
- name: 💁 Support/Chat
|
- name: 💁 Support/Chat
|
||||||
url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
|
url: https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ
|
||||||
about: 'For faster support, join us on Slack for assistance!'
|
about: 'For faster support, join us on Slack for assistance!'
|
||||||
|
|||||||
4
.github/workflows/gofmt.yml
vendored
4
.github/workflows/gofmt.yml
vendored
@@ -16,9 +16,9 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.24'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Install goimports
|
- name: Install goimports
|
||||||
|
|||||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -12,9 +12,9 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.24'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -35,9 +35,9 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.24'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -68,9 +68,9 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.24'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Import certificates
|
- name: Import certificates
|
||||||
|
|||||||
4
.github/workflows/smoke-extra.yml
vendored
4
.github/workflows/smoke-extra.yml
vendored
@@ -22,9 +22,9 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version-file: 'go.mod'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: add hashicorp source
|
- name: add hashicorp source
|
||||||
|
|||||||
4
.github/workflows/smoke.yml
vendored
4
.github/workflows/smoke.yml
vendored
@@ -20,9 +20,9 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.24'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: build
|
- name: build
|
||||||
|
|||||||
24
.github/workflows/test.yml
vendored
24
.github/workflows/test.yml
vendored
@@ -20,9 +20,9 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.24'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -32,9 +32,9 @@ jobs:
|
|||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v8
|
uses: golangci/golangci-lint-action@v7
|
||||||
with:
|
with:
|
||||||
version: v2.5
|
version: v2.0
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: make test
|
run: make test
|
||||||
@@ -58,9 +58,9 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.24'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -79,9 +79,9 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -100,9 +100,9 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.25'
|
go-version: '1.24'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build nebula
|
- name: Build nebula
|
||||||
@@ -115,9 +115,9 @@ jobs:
|
|||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v8
|
uses: golangci/golangci-lint-action@v7
|
||||||
with:
|
with:
|
||||||
version: v2.5
|
version: v2.0
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: make test
|
run: make test
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
|
|||||||
|
|
||||||
You can read more about Nebula [here](https://medium.com/p/884110a5579).
|
You can read more about Nebula [here](https://medium.com/p/884110a5579).
|
||||||
|
|
||||||
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
|
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
|
||||||
|
|
||||||
## Supported Platforms
|
## Supported Platforms
|
||||||
|
|
||||||
|
|||||||
3
bits.go
3
bits.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
|
|
||||||
type Bits struct {
|
type Bits struct {
|
||||||
length uint64
|
length uint64
|
||||||
current uint64
|
current uint64
|
||||||
@@ -44,7 +43,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
l.Error("rejected a packet (top) %d %d\n", b.current, i)
|
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -84,11 +84,16 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
|||||||
|
|
||||||
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
|
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
|
||||||
|
|
||||||
rawMap, ok := value.(map[string]any)
|
rawMap, ok := value.(map[any]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||||
}
|
}
|
||||||
for rawCIDR, rawValue := range rawMap {
|
for rawKey, rawValue := range rawMap {
|
||||||
|
rawCIDR, ok := rawKey.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||||
|
}
|
||||||
|
|
||||||
cidr, err := netip.ParsePrefix(rawCIDR)
|
cidr, err := netip.ParsePrefix(rawCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||||
@@ -124,7 +129,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
||||||
rawMap, ok := raw.(map[string]any)
|
rawMap, ok := raw.(map[any]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid type: %T", raw)
|
return nil, fmt.Errorf("invalid type: %T", raw)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,9 +58,6 @@ type Certificate interface {
|
|||||||
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
||||||
PublicKey() []byte
|
PublicKey() []byte
|
||||||
|
|
||||||
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
|
|
||||||
MarshalPublicKeyPEM() []byte
|
|
||||||
|
|
||||||
// Curve identifies which curve was used for the PublicKey and Signature.
|
// Curve identifies which curve was used for the PublicKey and Signature.
|
||||||
Curve() Curve
|
Curve() Curve
|
||||||
|
|
||||||
@@ -138,7 +135,8 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
|
|||||||
case Version2:
|
case Version2:
|
||||||
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
|
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
|
||||||
default:
|
default:
|
||||||
return nil, ErrUnknownVersion
|
//TODO: CERT-V2 make a static var
|
||||||
|
return nil, fmt.Errorf("unknown certificate version %d", v)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -83,10 +83,6 @@ func (c *certificateV1) PublicKey() []byte {
|
|||||||
return c.details.publicKey
|
return c.details.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
|
|
||||||
return marshalCertPublicKeyToPEM(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certificateV1) Signature() []byte {
|
func (c *certificateV1) Signature() []byte {
|
||||||
return c.signature
|
return c.signature
|
||||||
}
|
}
|
||||||
@@ -114,10 +110,8 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
|
|||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
return ed25519.Verify(key, b, c.signature)
|
return ed25519.Verify(key, b, c.signature)
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||||
if err != nil {
|
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||||
return false
|
|
||||||
}
|
|
||||||
hashed := sha256.Sum256(b)
|
hashed := sha256.Sum256(b)
|
||||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ed25519"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -14,7 +13,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCertificateV1_Marshal(t *testing.T) {
|
func TestCertificateV1_Marshal(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||||
@@ -62,58 +60,6 @@ func TestCertificateV1_Marshal(t *testing.T) {
|
|||||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_PublicKeyPem(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
|
||||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
|
||||||
|
|
||||||
nc := certificateV1{
|
|
||||||
details: detailsV1{
|
|
||||||
name: "testing",
|
|
||||||
networks: []netip.Prefix{},
|
|
||||||
unsafeNetworks: []netip.Prefix{},
|
|
||||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
|
||||||
notBefore: before,
|
|
||||||
notAfter: after,
|
|
||||||
publicKey: pubKey,
|
|
||||||
isCA: false,
|
|
||||||
issuer: "1234567890abcedfghij1234567890ab",
|
|
||||||
},
|
|
||||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, Version1, nc.Version())
|
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
|
||||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
|
||||||
assert.False(t, nc.IsCA())
|
|
||||||
|
|
||||||
nc.details.isCA = true
|
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
|
||||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
|
||||||
assert.True(t, nc.IsCA())
|
|
||||||
|
|
||||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
|
||||||
-----END NEBULA P256 PUBLIC KEY-----
|
|
||||||
`)
|
|
||||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
|
||||||
require.NoError(t, err)
|
|
||||||
nc.details.curve = Curve_P256
|
|
||||||
nc.details.publicKey = pubP256Key
|
|
||||||
assert.Equal(t, Curve_P256, nc.Curve())
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
|
||||||
assert.True(t, nc.IsCA())
|
|
||||||
|
|
||||||
nc.details.isCA = false
|
|
||||||
assert.Equal(t, Curve_P256, nc.Curve())
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
|
||||||
assert.False(t, nc.IsCA())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateV1_Expired(t *testing.T) {
|
func TestCertificateV1_Expired(t *testing.T) {
|
||||||
nc := certificateV1{
|
nc := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
|
|||||||
@@ -114,10 +114,6 @@ func (c *certificateV2) PublicKey() []byte {
|
|||||||
return c.publicKey
|
return c.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
|
|
||||||
return marshalCertPublicKeyToPEM(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certificateV2) Signature() []byte {
|
func (c *certificateV2) Signature() []byte {
|
||||||
return c.signature
|
return c.signature
|
||||||
}
|
}
|
||||||
@@ -153,10 +149,8 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
|
|||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
return ed25519.Verify(key, b, c.signature)
|
return ed25519.Verify(key, b, c.signature)
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||||
if err != nil {
|
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||||
return false
|
|
||||||
}
|
|
||||||
hashed := sha256.Sum256(b)
|
hashed := sha256.Sum256(b)
|
||||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCertificateV2_Marshal(t *testing.T) {
|
func TestCertificateV2_Marshal(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||||
@@ -76,58 +75,6 @@ func TestCertificateV2_Marshal(t *testing.T) {
|
|||||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_PublicKeyPem(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
|
||||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
|
||||||
|
|
||||||
nc := certificateV2{
|
|
||||||
details: detailsV2{
|
|
||||||
name: "testing",
|
|
||||||
networks: []netip.Prefix{},
|
|
||||||
unsafeNetworks: []netip.Prefix{},
|
|
||||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
|
||||||
notBefore: before,
|
|
||||||
notAfter: after,
|
|
||||||
isCA: false,
|
|
||||||
issuer: "1234567890abcedfghij1234567890ab",
|
|
||||||
},
|
|
||||||
publicKey: pubKey,
|
|
||||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, Version2, nc.Version())
|
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
|
||||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
|
||||||
assert.False(t, nc.IsCA())
|
|
||||||
|
|
||||||
nc.details.isCA = true
|
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
|
||||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
|
||||||
assert.True(t, nc.IsCA())
|
|
||||||
|
|
||||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
|
||||||
-----END NEBULA P256 PUBLIC KEY-----
|
|
||||||
`)
|
|
||||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
|
||||||
require.NoError(t, err)
|
|
||||||
nc.curve = Curve_P256
|
|
||||||
nc.publicKey = pubP256Key
|
|
||||||
assert.Equal(t, Curve_P256, nc.Curve())
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
|
||||||
assert.True(t, nc.IsCA())
|
|
||||||
|
|
||||||
nc.details.isCA = false
|
|
||||||
assert.Equal(t, Curve_P256, nc.Curve())
|
|
||||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
|
||||||
assert.False(t, nc.IsCA())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateV2_Expired(t *testing.T) {
|
func TestCertificateV2_Expired(t *testing.T) {
|
||||||
nc := certificateV2{
|
nc := certificateV2{
|
||||||
details: detailsV2{
|
details: detailsV2{
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ var (
|
|||||||
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
|
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
|
||||||
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
|
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
|
||||||
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
||||||
ErrUnknownVersion = errors.New("certificate version unrecognized")
|
|
||||||
|
|
||||||
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
||||||
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
||||||
|
|||||||
141
cert/pem.go
141
cert/pem.go
@@ -1,34 +1,25 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ( //cert banners
|
const (
|
||||||
CertificateBanner = "NEBULA CERTIFICATE"
|
CertificateBanner = "NEBULA CERTIFICATE"
|
||||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||||
)
|
|
||||||
|
|
||||||
const ( //key-agreement-key banners
|
|
||||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
|
||||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
|
|
||||||
const ( //signing key banners
|
|
||||||
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
|
||||||
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
|
||||||
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
|
|
||||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||||
|
|
||||||
|
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||||
|
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||||
|
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
||||||
|
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
||||||
@@ -60,16 +51,6 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
|
||||||
if c.IsCA() {
|
|
||||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
|
||||||
} else {
|
|
||||||
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
|
|
||||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
|
||||||
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||||
switch curve {
|
switch curve {
|
||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
@@ -81,19 +62,6 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
|
|
||||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
|
||||||
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
|
|
||||||
switch curve {
|
|
||||||
case Curve_CURVE25519:
|
|
||||||
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
|
|
||||||
case Curve_P256:
|
|
||||||
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||||
k, r := pem.Decode(b)
|
k, r := pem.Decode(b)
|
||||||
if k == nil {
|
if k == nil {
|
||||||
@@ -105,7 +73,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
|||||||
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
||||||
expectedLen = 32
|
expectedLen = 32
|
||||||
curve = Curve_CURVE25519
|
curve = Curve_CURVE25519
|
||||||
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
|
case P256PublicKeyBanner:
|
||||||
// Uncompressed
|
// Uncompressed
|
||||||
expectedLen = 65
|
expectedLen = 65
|
||||||
curve = Curve_P256
|
curve = Curve_P256
|
||||||
@@ -140,101 +108,6 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backward compatibility functions for older API
|
|
||||||
func MarshalX25519PublicKey(b []byte) []byte {
|
|
||||||
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalX25519PrivateKey(b []byte) []byte {
|
|
||||||
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalPublicKey(curve Curve, b []byte) []byte {
|
|
||||||
return MarshalPublicKeyToPEM(curve, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalPrivateKey(curve Curve, b []byte) []byte {
|
|
||||||
return MarshalPrivateKeyToPEM(curve, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NebulaCertificate is a compatibility wrapper for the old API
|
|
||||||
type NebulaCertificate struct {
|
|
||||||
Details NebulaCertificateDetails
|
|
||||||
Signature []byte
|
|
||||||
cert Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
// NebulaCertificateDetails is a compatibility wrapper for certificate details
|
|
||||||
type NebulaCertificateDetails struct {
|
|
||||||
Name string
|
|
||||||
NotBefore time.Time
|
|
||||||
NotAfter time.Time
|
|
||||||
PublicKey []byte
|
|
||||||
IsCA bool
|
|
||||||
Issuer []byte
|
|
||||||
Curve Curve
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
|
|
||||||
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
|
|
||||||
c, rest, err := UnmarshalCertificateFromPEM(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, rest, err
|
|
||||||
}
|
|
||||||
|
|
||||||
issuerBytes, err := func() ([]byte, error) {
|
|
||||||
issuer := c.Issuer()
|
|
||||||
if issuer == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
decoded, err := hex.DecodeString(issuer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
|
|
||||||
}
|
|
||||||
return decoded, nil
|
|
||||||
}()
|
|
||||||
if err != nil {
|
|
||||||
return nil, rest, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey := c.PublicKey()
|
|
||||||
if pubKey != nil {
|
|
||||||
pubKey = append([]byte(nil), pubKey...)
|
|
||||||
}
|
|
||||||
|
|
||||||
sig := c.Signature()
|
|
||||||
if sig != nil {
|
|
||||||
sig = append([]byte(nil), sig...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &NebulaCertificate{
|
|
||||||
Details: NebulaCertificateDetails{
|
|
||||||
Name: c.Name(),
|
|
||||||
NotBefore: c.NotBefore(),
|
|
||||||
NotAfter: c.NotAfter(),
|
|
||||||
PublicKey: pubKey,
|
|
||||||
IsCA: c.IsCA(),
|
|
||||||
Issuer: issuerBytes,
|
|
||||||
Curve: c.Curve(),
|
|
||||||
},
|
|
||||||
Signature: sig,
|
|
||||||
cert: c,
|
|
||||||
}, rest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IssuerString returns the issuer in hex format for compatibility
|
|
||||||
func (n *NebulaCertificate) IssuerString() string {
|
|
||||||
if n.Details.Issuer == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(n.Details.Issuer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Certificate returns the underlying certificate (read-only)
|
|
||||||
func (n *NebulaCertificate) Certificate() Certificate {
|
|
||||||
return n.cert
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
||||||
// consumed data or an error on failure
|
// consumed data or an error on failure
|
||||||
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||||
|
|||||||
@@ -177,7 +177,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
pubKey := []byte(`# A good key
|
pubKey := []byte(`# A good key
|
||||||
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
@@ -231,7 +230,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
pubKey := []byte(`# A good key
|
pubKey := []byte(`# A good key
|
||||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
@@ -242,12 +240,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
-----END NEBULA P256 PUBLIC KEY-----
|
-----END NEBULA P256 PUBLIC KEY-----
|
||||||
`)
|
|
||||||
oldPubP256Key := []byte(`# A good key
|
|
||||||
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
|
||||||
-----END NEBULA ECDSA P256 PUBLIC KEY-----
|
|
||||||
`)
|
`)
|
||||||
shortKey := []byte(`# A short key
|
shortKey := []byte(`# A short key
|
||||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||||
@@ -264,22 +256,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
-END NEBULA X25519 PUBLIC KEY-----`)
|
-END NEBULA X25519 PUBLIC KEY-----`)
|
||||||
|
|
||||||
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
|
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||||
assert.Len(t, k, 32)
|
assert.Len(t, k, 32)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
|
|
||||||
// Success test case
|
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
|
||||||
assert.Len(t, k, 65)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
|
||||||
assert.Equal(t, Curve_P256, curve)
|
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Len(t, k, 65)
|
assert.Len(t, k, 65)
|
||||||
|
|||||||
12
cert/sign.go
12
cert/sign.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -54,10 +55,15 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
|
|||||||
}
|
}
|
||||||
return t.SignWith(signer, curve, sp)
|
return t.SignWith(signer, curve, sp)
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
|
pk := &ecdsa.PrivateKey{
|
||||||
if err != nil {
|
PublicKey: ecdsa.PublicKey{
|
||||||
return nil, err
|
Curve: elliptic.P256(),
|
||||||
|
},
|
||||||
|
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
|
||||||
|
D: new(big.Int).SetBytes(key),
|
||||||
}
|
}
|
||||||
|
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
|
||||||
|
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
|
||||||
sp := func(certBytes []byte) ([]byte, error) {
|
sp := func(certBytes []byte) ([]byte, error) {
|
||||||
// We need to hash first for ECDSA
|
// We need to hash first for ECDSA
|
||||||
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
||||||
|
|||||||
@@ -65,16 +65,8 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
wait, err := ctrl.Start()
|
ctrl.Start()
|
||||||
if err != nil {
|
ctrl.ShutdownBlock()
|
||||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go ctrl.ShutdownBlock()
|
|
||||||
wait()
|
|
||||||
|
|
||||||
l.Info("Goodbye")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -3,9 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
_ "net/http/pprof"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -61,22 +58,10 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
|
||||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
wait, err := ctrl.Start()
|
ctrl.Start()
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go ctrl.ShutdownBlock()
|
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
wait()
|
ctrl.ShutdownBlock()
|
||||||
|
|
||||||
l.Info("Goodbye")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -4,10 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
@@ -29,127 +27,311 @@ const (
|
|||||||
sendTestPacket trafficDecision = 6
|
sendTestPacket trafficDecision = 6
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// LastCommunication tracks when we last communicated with a host
|
||||||
|
type LastCommunication struct {
|
||||||
|
timestamp time.Time
|
||||||
|
vpnIp netip.Addr // To help with logging
|
||||||
|
}
|
||||||
|
|
||||||
type connectionManager struct {
|
type connectionManager struct {
|
||||||
|
in map[uint32]struct{}
|
||||||
|
inLock *sync.RWMutex
|
||||||
|
|
||||||
|
out map[uint32]struct{}
|
||||||
|
outLock *sync.RWMutex
|
||||||
|
|
||||||
// relayUsed holds which relay localIndexs are in use
|
// relayUsed holds which relay localIndexs are in use
|
||||||
relayUsed map[uint32]struct{}
|
relayUsed map[uint32]struct{}
|
||||||
relayUsedLock *sync.RWMutex
|
relayUsedLock *sync.RWMutex
|
||||||
|
|
||||||
|
// Track last communication with hosts
|
||||||
|
lastCommMap map[uint32]time.Time
|
||||||
|
lastCommLock *sync.RWMutex
|
||||||
|
inactivityTimer *LockingTimerWheel[uint32]
|
||||||
|
inactivityTimeout time.Duration
|
||||||
|
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
trafficTimer *LockingTimerWheel[uint32]
|
trafficTimer *LockingTimerWheel[uint32]
|
||||||
intf *Interface
|
intf *Interface
|
||||||
|
pendingDeletion map[uint32]struct{}
|
||||||
punchy *Punchy
|
punchy *Punchy
|
||||||
|
|
||||||
// Configuration settings
|
|
||||||
checkInterval time.Duration
|
checkInterval time.Duration
|
||||||
pendingDeletionInterval time.Duration
|
pendingDeletionInterval time.Duration
|
||||||
inactivityTimeout atomic.Int64
|
|
||||||
dropInactive atomic.Bool
|
|
||||||
|
|
||||||
metricsTxPunchy metrics.Counter
|
metricsTxPunchy metrics.Counter
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
|
||||||
cm := &connectionManager{
|
var max time.Duration
|
||||||
hostMap: hm,
|
if checkInterval < pendingDeletionInterval {
|
||||||
l: l,
|
max = pendingDeletionInterval
|
||||||
punchy: p,
|
} else {
|
||||||
|
max = checkInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
nc := &connectionManager{
|
||||||
|
hostMap: intf.hostMap,
|
||||||
|
in: make(map[uint32]struct{}),
|
||||||
|
inLock: &sync.RWMutex{},
|
||||||
|
out: make(map[uint32]struct{}),
|
||||||
|
outLock: &sync.RWMutex{},
|
||||||
relayUsed: make(map[uint32]struct{}),
|
relayUsed: make(map[uint32]struct{}),
|
||||||
relayUsedLock: &sync.RWMutex{},
|
relayUsedLock: &sync.RWMutex{},
|
||||||
|
lastCommMap: make(map[uint32]time.Time),
|
||||||
|
lastCommLock: &sync.RWMutex{},
|
||||||
|
inactivityTimeout: 1 * time.Minute, // Default inactivity timeout: 10 minutes
|
||||||
|
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
|
||||||
|
intf: intf,
|
||||||
|
pendingDeletion: make(map[uint32]struct{}),
|
||||||
|
checkInterval: checkInterval,
|
||||||
|
pendingDeletionInterval: pendingDeletionInterval,
|
||||||
|
punchy: punchy,
|
||||||
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.reload(c, true)
|
// Initialize the inactivity timer wheel - make wheel duration slightly longer than the timeout
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
nc.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, nc.inactivityTimeout+time.Minute)
|
||||||
cm.reload(c, false)
|
|
||||||
})
|
|
||||||
|
|
||||||
return cm
|
nc.Start(ctx)
|
||||||
|
return nc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) reload(c *config.C, initial bool) {
|
func (n *connectionManager) updateLastCommunication(localIndex uint32) {
|
||||||
if initial {
|
// Get host info to record VPN IP for better logging
|
||||||
cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
|
hostInfo := n.hostMap.QueryIndex(localIndex)
|
||||||
cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
|
if hostInfo == nil {
|
||||||
|
|
||||||
// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
|
|
||||||
// pretty close to their configured duration.
|
|
||||||
// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
|
|
||||||
minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
|
|
||||||
maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
|
|
||||||
cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
|
|
||||||
}
|
|
||||||
|
|
||||||
if initial || c.HasChanged("tunnels.inactivity_timeout") {
|
|
||||||
old := cm.getInactivityTimeout()
|
|
||||||
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
|
||||||
if !initial {
|
|
||||||
cm.l.WithField("oldDuration", old).
|
|
||||||
WithField("newDuration", cm.getInactivityTimeout()).
|
|
||||||
Info("Inactivity timeout has changed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if initial || c.HasChanged("tunnels.drop_inactive") {
|
|
||||||
old := cm.dropInactive.Load()
|
|
||||||
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
|
||||||
if !initial {
|
|
||||||
cm.l.WithField("oldBool", old).
|
|
||||||
WithField("newBool", cm.dropInactive.Load()).
|
|
||||||
Info("Drop inactive setting has changed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *connectionManager) getInactivityTimeout() time.Duration {
|
|
||||||
return (time.Duration)(cm.inactivityTimeout.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *connectionManager) In(h *HostInfo) {
|
|
||||||
h.in.Store(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *connectionManager) Out(h *HostInfo) {
|
|
||||||
h.out.Store(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *connectionManager) RelayUsed(localIndex uint32) {
|
|
||||||
cm.relayUsedLock.RLock()
|
|
||||||
// If this already exists, return
|
|
||||||
if _, ok := cm.relayUsed[localIndex]; ok {
|
|
||||||
cm.relayUsedLock.RUnlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cm.relayUsedLock.RUnlock()
|
|
||||||
cm.relayUsedLock.Lock()
|
now := time.Now()
|
||||||
cm.relayUsed[localIndex] = struct{}{}
|
n.lastCommLock.Lock()
|
||||||
cm.relayUsedLock.Unlock()
|
n.lastCommMap[localIndex] = now
|
||||||
|
n.lastCommLock.Unlock()
|
||||||
|
|
||||||
|
// Reset the inactivity timer for this host
|
||||||
|
n.inactivityTimer.m.Lock()
|
||||||
|
n.inactivityTimer.t.Add(localIndex, n.inactivityTimeout)
|
||||||
|
n.inactivityTimer.m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *connectionManager) In(localIndex uint32) {
|
||||||
|
n.inLock.RLock()
|
||||||
|
// If this already exists, return
|
||||||
|
if _, ok := n.in[localIndex]; ok {
|
||||||
|
n.inLock.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.inLock.RUnlock()
|
||||||
|
n.inLock.Lock()
|
||||||
|
n.in[localIndex] = struct{}{}
|
||||||
|
n.inLock.Unlock()
|
||||||
|
|
||||||
|
// Update last communication time
|
||||||
|
n.updateLastCommunication(localIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *connectionManager) Out(localIndex uint32) {
|
||||||
|
n.outLock.RLock()
|
||||||
|
// If this already exists, return
|
||||||
|
if _, ok := n.out[localIndex]; ok {
|
||||||
|
n.outLock.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.outLock.RUnlock()
|
||||||
|
n.outLock.Lock()
|
||||||
|
n.out[localIndex] = struct{}{}
|
||||||
|
n.outLock.Unlock()
|
||||||
|
|
||||||
|
// Update last communication time
|
||||||
|
n.updateLastCommunication(localIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *connectionManager) RelayUsed(localIndex uint32) {
|
||||||
|
n.relayUsedLock.RLock()
|
||||||
|
// If this already exists, return
|
||||||
|
if _, ok := n.relayUsed[localIndex]; ok {
|
||||||
|
n.relayUsedLock.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.relayUsedLock.RUnlock()
|
||||||
|
n.relayUsedLock.Lock()
|
||||||
|
n.relayUsed[localIndex] = struct{}{}
|
||||||
|
n.relayUsedLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
|
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
|
||||||
// resets the state for this local index
|
// resets the state for this local index
|
||||||
func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
|
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
|
||||||
in := h.in.Swap(false)
|
n.inLock.Lock()
|
||||||
out := h.out.Swap(false)
|
n.outLock.Lock()
|
||||||
if in || out {
|
_, in := n.in[localIndex]
|
||||||
h.lastUsed = now
|
_, out := n.out[localIndex]
|
||||||
}
|
delete(n.in, localIndex)
|
||||||
|
delete(n.out, localIndex)
|
||||||
|
n.inLock.Unlock()
|
||||||
|
n.outLock.Unlock()
|
||||||
return in, out
|
return in, out
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTrafficWatch must be called for every new HostInfo.
|
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
|
||||||
// We will continue to monitor the HostInfo until the tunnel is dropped.
|
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
|
||||||
func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
|
n.outLock.Lock()
|
||||||
if h.out.Swap(true) == false {
|
if _, ok := n.out[localIndex]; ok {
|
||||||
cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
|
n.outLock.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.out[localIndex] = struct{}{}
|
||||||
|
n.trafficTimer.Add(localIndex, n.checkInterval)
|
||||||
|
n.outLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkInactiveTunnels checks for tunnels that have been inactive for too long and drops them
|
||||||
|
func (n *connectionManager) checkInactiveTunnels() {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// First, advance the timer wheel to the current time
|
||||||
|
n.inactivityTimer.m.Lock()
|
||||||
|
n.inactivityTimer.t.Advance(now)
|
||||||
|
n.inactivityTimer.m.Unlock()
|
||||||
|
|
||||||
|
// Check for expired timers (inactive connections)
|
||||||
|
for {
|
||||||
|
// Get the next expired tunnel
|
||||||
|
n.inactivityTimer.m.Lock()
|
||||||
|
localIndex, ok := n.inactivityTimer.t.Purge()
|
||||||
|
n.inactivityTimer.m.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
// No more expired timers
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
n.lastCommLock.RLock()
|
||||||
|
lastComm, exists := n.lastCommMap[localIndex]
|
||||||
|
n.lastCommLock.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
// No last communication record, odd but skip
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate inactivity duration
|
||||||
|
inactiveDuration := now.Sub(lastComm)
|
||||||
|
|
||||||
|
// Check if we've exceeded the inactivity timeout
|
||||||
|
if inactiveDuration >= n.inactivityTimeout {
|
||||||
|
// Get the host info (if it still exists)
|
||||||
|
hostInfo := n.hostMap.QueryIndex(localIndex)
|
||||||
|
if hostInfo == nil {
|
||||||
|
// Host info is gone, remove from our tracking map
|
||||||
|
n.lastCommLock.Lock()
|
||||||
|
delete(n.lastCommMap, localIndex)
|
||||||
|
n.lastCommLock.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log the inactivity and drop the tunnel
|
||||||
|
n.l.WithField("vpnIp", hostInfo.vpnAddrs[0]).
|
||||||
|
WithField("localIndex", localIndex).
|
||||||
|
WithField("inactiveDuration", inactiveDuration).
|
||||||
|
WithField("timeout", n.inactivityTimeout).
|
||||||
|
Info("Dropping tunnel due to inactivity")
|
||||||
|
|
||||||
|
// Close the tunnel using the existing mechanism
|
||||||
|
n.intf.closeTunnel(hostInfo)
|
||||||
|
|
||||||
|
// Clean up our tracking map
|
||||||
|
n.lastCommLock.Lock()
|
||||||
|
delete(n.lastCommMap, localIndex)
|
||||||
|
n.lastCommLock.Unlock()
|
||||||
|
} else {
|
||||||
|
// Re-add to the timer wheel with the remaining time
|
||||||
|
remainingTime := n.inactivityTimeout - inactiveDuration
|
||||||
|
n.inactivityTimer.m.Lock()
|
||||||
|
n.inactivityTimer.t.Add(localIndex, remainingTime)
|
||||||
|
n.inactivityTimer.m.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) Start(ctx context.Context) {
|
// CleanupDeletedHostInfos removes entries from our lastCommMap for hosts that no longer exist
|
||||||
clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
|
func (n *connectionManager) CleanupDeletedHostInfos() {
|
||||||
|
n.lastCommLock.Lock()
|
||||||
|
defer n.lastCommLock.Unlock()
|
||||||
|
|
||||||
|
// Find indexes to delete
|
||||||
|
var toDelete []uint32
|
||||||
|
for localIndex := range n.lastCommMap {
|
||||||
|
if n.hostMap.QueryIndex(localIndex) == nil {
|
||||||
|
toDelete = append(toDelete, localIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete them
|
||||||
|
for _, localIndex := range toDelete {
|
||||||
|
delete(n.lastCommMap, localIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toDelete) > 0 && n.l.Level >= logrus.DebugLevel {
|
||||||
|
n.l.WithField("count", len(toDelete)).Debug("Cleaned up deleted host entries from lastCommMap")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadConfig updates the connection manager configuration
|
||||||
|
func (n *connectionManager) ReloadConfig(c *config.C) {
|
||||||
|
// Get the inactivity timeout from config
|
||||||
|
inactivityTimeout := c.GetDuration("timers.inactivity_timeout", 10*time.Minute)
|
||||||
|
|
||||||
|
// Only update if different
|
||||||
|
if inactivityTimeout != n.inactivityTimeout {
|
||||||
|
n.l.WithField("old", n.inactivityTimeout).
|
||||||
|
WithField("new", inactivityTimeout).
|
||||||
|
Info("Updating inactivity timeout")
|
||||||
|
|
||||||
|
n.inactivityTimeout = inactivityTimeout
|
||||||
|
|
||||||
|
// Recreate the inactivity timer wheel with the new timeout
|
||||||
|
n.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, n.inactivityTimeout+time.Minute)
|
||||||
|
|
||||||
|
// Re-add all existing hosts to the new timer wheel
|
||||||
|
n.lastCommLock.RLock()
|
||||||
|
for localIndex, lastComm := range n.lastCommMap {
|
||||||
|
// Calculate remaining time based on last communication
|
||||||
|
now := time.Now()
|
||||||
|
elapsed := now.Sub(lastComm)
|
||||||
|
|
||||||
|
// If the elapsed time exceeds the new timeout, this will be caught
|
||||||
|
// in the next inactivity check. Otherwise, add with remaining time.
|
||||||
|
if elapsed < n.inactivityTimeout {
|
||||||
|
remainingTime := n.inactivityTimeout - elapsed
|
||||||
|
n.inactivityTimer.m.Lock()
|
||||||
|
n.inactivityTimer.t.Add(localIndex, remainingTime)
|
||||||
|
n.inactivityTimer.m.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n.lastCommLock.RUnlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *connectionManager) Start(ctx context.Context) {
|
||||||
|
go n.Run(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *connectionManager) Run(ctx context.Context) {
|
||||||
|
//TODO: this tick should be based on the min wheel tick? Check firewall
|
||||||
|
clockSource := time.NewTicker(500 * time.Millisecond)
|
||||||
defer clockSource.Stop()
|
defer clockSource.Stop()
|
||||||
|
|
||||||
|
// Create ticker for inactivity checks (every minute)
|
||||||
|
inactivityTicker := time.NewTicker(time.Minute)
|
||||||
|
defer inactivityTicker.Stop()
|
||||||
|
|
||||||
|
// Create ticker for cleanup (every 5 minutes)
|
||||||
|
cleanupTicker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer cleanupTicker.Stop()
|
||||||
|
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
@@ -160,61 +342,69 @@ func (cm *connectionManager) Start(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
|
|
||||||
case now := <-clockSource.C:
|
case now := <-clockSource.C:
|
||||||
cm.trafficTimer.Advance(now)
|
n.trafficTimer.Advance(now)
|
||||||
for {
|
for {
|
||||||
localIndex, has := cm.trafficTimer.Purge()
|
localIndex, has := n.trafficTimer.Purge()
|
||||||
if !has {
|
if !has {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.doTrafficCheck(localIndex, p, nb, out, now)
|
n.doTrafficCheck(localIndex, p, nb, out, now)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case <-inactivityTicker.C:
|
||||||
|
// Check for inactive tunnels
|
||||||
|
n.checkInactiveTunnels()
|
||||||
|
|
||||||
|
case <-cleanupTicker.C:
|
||||||
|
// Periodically clean up deleted hosts
|
||||||
|
n.CleanupDeletedHostInfos()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
|
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
|
||||||
decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
|
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
|
||||||
|
|
||||||
switch decision {
|
switch decision {
|
||||||
case deleteTunnel:
|
case deleteTunnel:
|
||||||
if cm.hostMap.DeleteHostInfo(hostinfo) {
|
if n.hostMap.DeleteHostInfo(hostinfo) {
|
||||||
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
|
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
|
||||||
cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
|
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
case closeTunnel:
|
case closeTunnel:
|
||||||
cm.intf.sendCloseTunnel(hostinfo)
|
n.intf.sendCloseTunnel(hostinfo)
|
||||||
cm.intf.closeTunnel(hostinfo)
|
n.intf.closeTunnel(hostinfo)
|
||||||
|
|
||||||
case swapPrimary:
|
case swapPrimary:
|
||||||
cm.swapPrimary(hostinfo, primary)
|
n.swapPrimary(hostinfo, primary)
|
||||||
|
|
||||||
case migrateRelays:
|
case migrateRelays:
|
||||||
cm.migrateRelayUsed(hostinfo, primary)
|
n.migrateRelayUsed(hostinfo, primary)
|
||||||
|
|
||||||
case tryRehandshake:
|
case tryRehandshake:
|
||||||
cm.tryRehandshake(hostinfo)
|
n.tryRehandshake(hostinfo)
|
||||||
|
|
||||||
case sendTestPacket:
|
case sendTestPacket:
|
||||||
cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
|
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.resetRelayTrafficCheck(hostinfo)
|
n.resetRelayTrafficCheck(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
|
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
|
||||||
if hostinfo != nil {
|
if hostinfo != nil {
|
||||||
cm.relayUsedLock.Lock()
|
n.relayUsedLock.Lock()
|
||||||
defer cm.relayUsedLock.Unlock()
|
defer n.relayUsedLock.Unlock()
|
||||||
// No need to migrate any relays, delete usage info now.
|
// No need to migrate any relays, delete usage info now.
|
||||||
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
|
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
|
||||||
delete(cm.relayUsed, idx)
|
delete(n.relayUsed, idx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
|
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
|
||||||
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
|
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
|
||||||
|
|
||||||
for _, r := range relayFor {
|
for _, r := range relayFor {
|
||||||
@@ -224,51 +414,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
var relayFrom netip.Addr
|
var relayFrom netip.Addr
|
||||||
var relayTo netip.Addr
|
var relayTo netip.Addr
|
||||||
switch {
|
switch {
|
||||||
case ok:
|
case ok && existing.State == Established:
|
||||||
switch existing.State {
|
|
||||||
case Established, PeerRequested, Disestablished:
|
|
||||||
// This relay already exists in newhostinfo, then do nothing.
|
// This relay already exists in newhostinfo, then do nothing.
|
||||||
continue
|
continue
|
||||||
case Requested:
|
case ok && existing.State == Requested:
|
||||||
// The relay exists in a Requested state; re-send the request
|
// The relay exists in a Requested state; re-send the request
|
||||||
index = existing.LocalIndex
|
index = existing.LocalIndex
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
case TerminalType:
|
case TerminalType:
|
||||||
relayFrom = cm.intf.myVpnAddrs[0]
|
relayFrom = n.intf.myVpnAddrs[0]
|
||||||
relayTo = existing.PeerAddr
|
relayTo = existing.PeerAddr
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
relayFrom = existing.PeerAddr
|
relayFrom = existing.PeerAddr
|
||||||
relayTo = newhostinfo.vpnAddrs[0]
|
relayTo = newhostinfo.vpnAddrs[0]
|
||||||
default:
|
default:
|
||||||
// should never happen
|
// should never happen
|
||||||
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case !ok:
|
case !ok:
|
||||||
cm.relayUsedLock.RLock()
|
n.relayUsedLock.RLock()
|
||||||
if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
|
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
|
||||||
// The relay hasn't been used; don't migrate it.
|
// The relay hasn't been used; don't migrate it.
|
||||||
cm.relayUsedLock.RUnlock()
|
n.relayUsedLock.RUnlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
cm.relayUsedLock.RUnlock()
|
n.relayUsedLock.RUnlock()
|
||||||
// The relay doesn't exist at all; create some relay state and send the request.
|
// The relay doesn't exist at all; create some relay state and send the request.
|
||||||
var err error
|
var err error
|
||||||
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
case TerminalType:
|
case TerminalType:
|
||||||
relayFrom = cm.intf.myVpnAddrs[0]
|
relayFrom = n.intf.myVpnAddrs[0]
|
||||||
relayTo = r.PeerAddr
|
relayTo = r.PeerAddr
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
relayFrom = r.PeerAddr
|
relayFrom = r.PeerAddr
|
||||||
relayTo = newhostinfo.vpnAddrs[0]
|
relayTo = newhostinfo.vpnAddrs[0]
|
||||||
default:
|
default:
|
||||||
// should never happen
|
// should never happen
|
||||||
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,12 +466,12 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
switch newhostinfo.GetCert().Certificate.Version() {
|
switch newhostinfo.GetCert().Certificate.Version() {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
if !relayFrom.Is4() {
|
if !relayFrom.Is4() {
|
||||||
cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
|
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !relayTo.Is4() {
|
if !relayTo.Is4() {
|
||||||
cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,16 +483,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||||
req.RelayToAddr = netAddrToProtoAddr(relayTo)
|
req.RelayToAddr = netAddrToProtoAddr(relayTo)
|
||||||
default:
|
default:
|
||||||
newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
|
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
||||||
} else {
|
} else {
|
||||||
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
cm.l.WithFields(logrus.Fields{
|
n.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": req.RelayFromAddr,
|
"relayFrom": req.RelayFromAddr,
|
||||||
"relayTo": req.RelayToAddr,
|
"relayTo": req.RelayToAddr,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||||
@@ -318,45 +503,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
|
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
|
||||||
// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
|
n.hostMap.RLock()
|
||||||
cm.hostMap.RLock()
|
defer n.hostMap.RUnlock()
|
||||||
defer cm.hostMap.RUnlock()
|
|
||||||
|
|
||||||
hostinfo := cm.hostMap.Indexes[localIndex]
|
hostinfo := n.hostMap.Indexes[localIndex]
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
|
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
|
||||||
|
delete(n.pendingDeletion, localIndex)
|
||||||
return doNothing, nil, nil
|
return doNothing, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm.isInvalidCertificate(now, hostinfo) {
|
if n.isInvalidCertificate(now, hostinfo) {
|
||||||
|
delete(n.pendingDeletion, hostinfo.localIndexId)
|
||||||
return closeTunnel, hostinfo, nil
|
return closeTunnel, hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
|
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
|
||||||
mainHostInfo := true
|
mainHostInfo := true
|
||||||
if primary != nil && primary != hostinfo {
|
if primary != nil && primary != hostinfo {
|
||||||
mainHostInfo = false
|
mainHostInfo = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for traffic on this hostinfo
|
// Check for traffic on this hostinfo
|
||||||
inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
|
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
|
||||||
|
|
||||||
// A hostinfo is determined alive if there is incoming traffic
|
// A hostinfo is determined alive if there is incoming traffic
|
||||||
if inTraffic {
|
if inTraffic {
|
||||||
decision := doNothing
|
decision := doNothing
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if n.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(n.l).
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
}
|
}
|
||||||
hostinfo.pendingDeletion.Store(false)
|
delete(n.pendingDeletion, hostinfo.localIndexId)
|
||||||
|
|
||||||
if mainHostInfo {
|
if mainHostInfo {
|
||||||
decision = tryRehandshake
|
decision = tryRehandshake
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.shouldSwapPrimary(hostinfo) {
|
if n.shouldSwapPrimary(hostinfo, primary) {
|
||||||
decision = swapPrimary
|
decision = swapPrimary
|
||||||
} else {
|
} else {
|
||||||
// migrate the relays to the primary, if in use.
|
// migrate the relays to the primary, if in use.
|
||||||
@@ -364,55 +550,46 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
|
||||||
|
|
||||||
if !outTraffic {
|
if !outTraffic {
|
||||||
// Send a punch packet to keep the NAT state alive
|
// Send a punch packet to keep the NAT state alive
|
||||||
cm.sendPunch(hostinfo)
|
n.sendPunch(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
return decision, hostinfo, primary
|
return decision, hostinfo, primary
|
||||||
}
|
}
|
||||||
|
|
||||||
if hostinfo.pendingDeletion.Load() {
|
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
|
||||||
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(n.l).
|
||||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
||||||
Info("Tunnel status")
|
Info("Tunnel status")
|
||||||
|
|
||||||
|
delete(n.pendingDeletion, hostinfo.localIndexId)
|
||||||
return deleteTunnel, hostinfo, nil
|
return deleteTunnel, hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
decision := doNothing
|
decision := doNothing
|
||||||
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
|
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
|
||||||
if !outTraffic {
|
if !outTraffic {
|
||||||
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
|
||||||
if isInactive {
|
|
||||||
// Tunnel is inactive, tear it down
|
|
||||||
hostinfo.logger(cm.l).
|
|
||||||
WithField("inactiveDuration", inactiveFor).
|
|
||||||
WithField("primary", mainHostInfo).
|
|
||||||
Info("Dropping tunnel due to inactivity")
|
|
||||||
|
|
||||||
return closeTunnel, hostinfo, primary
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
|
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
|
||||||
// Just maintain NAT state if configured to do so.
|
// Just maintain NAT state if configured to do so.
|
||||||
cm.sendPunch(hostinfo)
|
n.sendPunch(hostinfo)
|
||||||
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
|
||||||
return doNothing, nil, nil
|
return doNothing, nil, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm.punchy.GetTargetEverything() {
|
if n.punchy.GetTargetEverything() {
|
||||||
// This is similar to the old punchy behavior with a slight optimization.
|
// This is similar to the old punchy behavior with a slight optimization.
|
||||||
// We aren't receiving traffic but we are sending it, punch on all known
|
// We aren't receiving traffic but we are sending it, punch on all known
|
||||||
// ips in case we need to re-prime NAT state
|
// ips in case we need to re-prime NAT state
|
||||||
cm.sendPunch(hostinfo)
|
n.sendPunch(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if n.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(n.l).
|
||||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
}
|
}
|
||||||
@@ -421,33 +598,17 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
decision = sendTestPacket
|
decision = sendTestPacket
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if n.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
|
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.pendingDeletion.Store(true)
|
n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
|
||||||
cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
|
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
|
||||||
return decision, hostinfo, nil
|
return decision, hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
|
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
||||||
if cm.dropInactive.Load() == false {
|
|
||||||
// We aren't configured to drop inactive tunnels
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
inactiveDuration := now.Sub(hostinfo.lastUsed)
|
|
||||||
if inactiveDuration < cm.getInactivityTimeout() {
|
|
||||||
// It's not considered inactive
|
|
||||||
return inactiveDuration, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// The tunnel is inactive
|
|
||||||
return inactiveDuration, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
|
||||||
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
|
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
|
||||||
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
||||||
// Let's sort this out.
|
// Let's sort this out.
|
||||||
@@ -455,80 +616,73 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
|||||||
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
|
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
|
||||||
// vpn addr is static across all tunnels for this host pair so lets
|
// vpn addr is static across all tunnels for this host pair so lets
|
||||||
// use that to determine if we should consider swapping.
|
// use that to determine if we should consider swapping.
|
||||||
if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
|
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
|
||||||
// Their primary vpn addr is less than mine. Do not swap.
|
// Their primary vpn addr is less than mine. Do not swap.
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||||
// settle down.
|
// settle down.
|
||||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||||
cm.hostMap.Lock()
|
n.hostMap.Lock()
|
||||||
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
|
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
|
||||||
if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
|
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
|
||||||
cm.hostMap.unlockedMakePrimary(current)
|
n.hostMap.unlockedMakePrimary(current)
|
||||||
}
|
}
|
||||||
cm.hostMap.Unlock()
|
n.hostMap.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||||
// check and return true.
|
// check and return true.
|
||||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||||
remoteCert := hostinfo.GetCert()
|
remoteCert := hostinfo.GetCert()
|
||||||
if remoteCert == nil {
|
if remoteCert == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
caPool := cm.intf.pki.GetCAPool()
|
caPool := n.intf.pki.GetCAPool()
|
||||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||||
// Block listed certificates should always be disconnected
|
// Block listed certificates should always be disconnected
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(n.l).WithError(err).
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||||
if !cm.punchy.GetPunch() {
|
if !n.punchy.GetPunch() {
|
||||||
// Punching is disabled
|
// Punching is disabled
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
if n.punchy.GetTargetEverything() {
|
||||||
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
|
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||||
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
|
n.metricsTxPunchy.Inc(1)
|
||||||
// would lose the ability to notify us and punchy.respond would become unreliable.
|
_ = n.intf.outside.WriteTo([]byte{1}, addr)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if cm.punchy.GetTargetEverything() {
|
|
||||||
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
|
||||||
cm.metricsTxPunchy.Inc(1)
|
|
||||||
cm.intf.outside.WriteTo([]byte{1}, addr)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
} else if hostinfo.remote.IsValid() {
|
} else if hostinfo.remote.IsValid() {
|
||||||
cm.metricsTxPunchy.Inc(1)
|
n.metricsTxPunchy.Inc(1)
|
||||||
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
_ = n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
cs := cm.intf.pki.getCertState()
|
cs := n.intf.pki.getCertState()
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
myCrt := cs.getCertificate(curCrt.Version())
|
myCrt := cs.getCertificate(curCrt.Version())
|
||||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||||
@@ -536,9 +690,9 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
WithField("reason", "local certificate is not current").
|
WithField("reason", "local certificate is not current").
|
||||||
Info("Re-handshaking with remote")
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -22,7 +23,7 @@ func newTestLighthouse() *LightHouse {
|
|||||||
addrMap: map[netip.Addr]*RemoteList{},
|
addrMap: map[netip.Addr]*RemoteList{},
|
||||||
queryChan: make(chan netip.Addr, 10),
|
queryChan: make(chan netip.Addr, 10),
|
||||||
}
|
}
|
||||||
lighthouses := []netip.Addr{}
|
lighthouses := map[netip.Addr]struct{}{}
|
||||||
staticList := map[netip.Addr]struct{}{}
|
staticList := map[netip.Addr]struct{}{}
|
||||||
|
|
||||||
lh.lighthouses.Store(&lighthouses)
|
lh.lighthouses.Store(&lighthouses)
|
||||||
@@ -63,10 +64,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
defer cancel()
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
||||||
nc.intf = ifce
|
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
@@ -84,33 +85,32 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
// We saw traffic out to vpnIp
|
// We saw traffic out to vpnIp
|
||||||
nc.Out(hostinfo)
|
nc.Out(hostinfo.localIndexId)
|
||||||
nc.In(hostinfo)
|
nc.In(hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.True(t, hostinfo.out.Load())
|
assert.Contains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.True(t, hostinfo.in.Load())
|
|
||||||
|
|
||||||
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.out.Load())
|
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.in.Load())
|
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||||
|
|
||||||
// Do another traffic check tick, this host should be pending deletion now
|
// Do another traffic check tick, this host should be pending deletion now
|
||||||
nc.Out(hostinfo)
|
nc.Out(hostinfo.localIndexId)
|
||||||
assert.True(t, hostinfo.out.Load())
|
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.True(t, hostinfo.pendingDeletion.Load())
|
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.out.Load())
|
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.in.Load())
|
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
|
|
||||||
// Do a final traffic check tick, the host should now be removed
|
// Do a final traffic check tick, the host should now be removed
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
|
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
defer cancel()
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
||||||
nc.intf = ifce
|
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
@@ -167,129 +167,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
// We saw traffic out to vpnIp
|
// We saw traffic out to vpnIp
|
||||||
nc.Out(hostinfo)
|
nc.Out(hostinfo.localIndexId)
|
||||||
nc.In(hostinfo)
|
nc.In(hostinfo.localIndexId)
|
||||||
assert.True(t, hostinfo.in.Load())
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
|
||||||
assert.True(t, hostinfo.out.Load())
|
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
|
|
||||||
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.out.Load())
|
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.in.Load())
|
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||||
|
|
||||||
// Do another traffic check tick, this host should be pending deletion now
|
// Do another traffic check tick, this host should be pending deletion now
|
||||||
nc.Out(hostinfo)
|
nc.Out(hostinfo.localIndexId)
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.True(t, hostinfo.pendingDeletion.Load())
|
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.out.Load())
|
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.in.Load())
|
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
|
|
||||||
// We saw traffic, should no longer be pending deletion
|
// We saw traffic, should no longer be pending deletion
|
||||||
nc.In(hostinfo)
|
nc.In(hostinfo.localIndexId)
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.out.Load())
|
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.False(t, hostinfo.in.Load())
|
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
|
||||||
vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
|
|
||||||
preferredRanges := []netip.Prefix{localrange}
|
|
||||||
|
|
||||||
// Very incomplete mock objects
|
|
||||||
hostMap := newHostMap(l)
|
|
||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
|
||||||
|
|
||||||
cs := &CertState{
|
|
||||||
initiatingVersion: cert.Version1,
|
|
||||||
privateKey: []byte{},
|
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
|
||||||
v1HandshakeBytes: []byte{},
|
|
||||||
}
|
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
|
||||||
ifce := &Interface{
|
|
||||||
hostMap: hostMap,
|
|
||||||
inside: &test.NoopTun{},
|
|
||||||
outside: &udp.NoopConn{},
|
|
||||||
firewall: &Firewall{},
|
|
||||||
lightHouse: lh,
|
|
||||||
pki: &PKI{},
|
|
||||||
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
|
||||||
l: l,
|
|
||||||
}
|
|
||||||
ifce.pki.cs.Store(cs)
|
|
||||||
|
|
||||||
// Create manager
|
|
||||||
conf := config.NewC(l)
|
|
||||||
conf.Settings["tunnels"] = map[string]any{
|
|
||||||
"drop_inactive": true,
|
|
||||||
}
|
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
|
||||||
assert.True(t, nc.dropInactive.Load())
|
|
||||||
nc.intf = ifce
|
|
||||||
|
|
||||||
// Add an ip we have established a connection w/ to hostmap
|
|
||||||
hostinfo := &HostInfo{
|
|
||||||
vpnAddrs: vpnAddrs,
|
|
||||||
localIndexId: 1099,
|
|
||||||
remoteIndexId: 9901,
|
|
||||||
}
|
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
|
||||||
myCert: &dummyCert{version: cert.Version1},
|
|
||||||
H: &noise.HandshakeState{},
|
|
||||||
}
|
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
|
||||||
|
|
||||||
// Do a traffic check tick, in and out should be cleared but should not be pending deletion
|
|
||||||
nc.Out(hostinfo)
|
|
||||||
nc.In(hostinfo)
|
|
||||||
assert.True(t, hostinfo.out.Load())
|
|
||||||
assert.True(t, hostinfo.in.Load())
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
|
|
||||||
assert.Equal(t, tryRehandshake, decision)
|
|
||||||
assert.Equal(t, now, hostinfo.lastUsed)
|
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
|
||||||
assert.False(t, hostinfo.out.Load())
|
|
||||||
assert.False(t, hostinfo.in.Load())
|
|
||||||
|
|
||||||
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
|
|
||||||
assert.Equal(t, doNothing, decision)
|
|
||||||
assert.Equal(t, now, hostinfo.lastUsed)
|
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
|
||||||
assert.False(t, hostinfo.out.Load())
|
|
||||||
assert.False(t, hostinfo.in.Load())
|
|
||||||
|
|
||||||
// Do another traffic check tick, should still not be pending deletion
|
|
||||||
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
|
|
||||||
assert.Equal(t, doNothing, decision)
|
|
||||||
assert.Equal(t, now, hostinfo.lastUsed)
|
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
|
||||||
assert.False(t, hostinfo.out.Load())
|
|
||||||
assert.False(t, hostinfo.in.Load())
|
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
|
||||||
|
|
||||||
// Finally advance beyond the inactivity timeout
|
|
||||||
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
|
|
||||||
assert.Equal(t, closeTunnel, decision)
|
|
||||||
assert.Equal(t, now, hostinfo.lastUsed)
|
|
||||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
|
||||||
assert.False(t, hostinfo.out.Load())
|
|
||||||
assert.False(t, hostinfo.in.Load())
|
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
@@ -360,10 +264,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
ifce.disconnectInvalid.Store(true)
|
ifce.disconnectInvalid.Store(true)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
defer cancel()
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
||||||
nc.intf = ifce
|
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
|
|
||||||
hostinfo := &HostInfo{
|
hostinfo := &HostInfo{
|
||||||
@@ -446,10 +350,6 @@ func (d *dummyCert) PublicKey() []byte {
|
|||||||
return d.publicKey
|
return d.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
|
|
||||||
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dummyCert) Signature() []byte {
|
func (d *dummyCert) Signature() []byte {
|
||||||
return d.signature
|
return d.signature
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
|
const ReplayWindow = 1024
|
||||||
// 4092 should be sufficient for 5Gbps
|
|
||||||
const ReplayWindow = 4096
|
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
eKey *NebulaCipherState
|
eKey *NebulaCipherState
|
||||||
|
|||||||
60
control.go
60
control.go
@@ -2,11 +2,9 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -15,16 +13,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RunState int
|
|
||||||
|
|
||||||
const (
|
|
||||||
Stopped RunState = 0 // The control has yet to be started
|
|
||||||
Started RunState = 1 // The control has been started
|
|
||||||
Stopping RunState = 2 // The control is stopping
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrAlreadyStarted = errors.New("nebula is already started")
|
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
@@ -38,9 +26,6 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
stateLock sync.Mutex
|
|
||||||
state RunState
|
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -49,7 +34,6 @@ type Control struct {
|
|||||||
statsStart func()
|
statsStart func()
|
||||||
dnsStart func()
|
dnsStart func()
|
||||||
lighthouseStart func()
|
lighthouseStart func()
|
||||||
connectionManagerStart func(context.Context)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ControlHostInfo struct {
|
type ControlHostInfo struct {
|
||||||
@@ -64,21 +48,10 @@ type ControlHostInfo struct {
|
|||||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call.
|
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||||
// The returned function can be used to wait for nebula to fully stop.
|
func (c *Control) Start() {
|
||||||
func (c *Control) Start() (func(), error) {
|
|
||||||
c.stateLock.Lock()
|
|
||||||
if c.state != Stopped {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
return nil, ErrAlreadyStarted
|
|
||||||
}
|
|
||||||
|
|
||||||
// Activate the interface
|
// Activate the interface
|
||||||
err := c.f.activate()
|
c.f.activate()
|
||||||
if err != nil {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||||
if c.sshStart != nil {
|
if c.sshStart != nil {
|
||||||
@@ -90,41 +63,20 @@ func (c *Control) Start() (func(), error) {
|
|||||||
if c.dnsStart != nil {
|
if c.dnsStart != nil {
|
||||||
go c.dnsStart()
|
go c.dnsStart()
|
||||||
}
|
}
|
||||||
if c.connectionManagerStart != nil {
|
|
||||||
go c.connectionManagerStart(c.ctx)
|
|
||||||
}
|
|
||||||
if c.lighthouseStart != nil {
|
if c.lighthouseStart != nil {
|
||||||
c.lighthouseStart()
|
c.lighthouseStart()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start reading packets.
|
// Start reading packets.
|
||||||
c.state = Started
|
c.f.run()
|
||||||
c.stateLock.Unlock()
|
|
||||||
return c.f.run(c.ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) State() RunState {
|
|
||||||
c.stateLock.Lock()
|
|
||||||
defer c.stateLock.Unlock()
|
|
||||||
return c.state
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Context() context.Context {
|
func (c *Control) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
||||||
func (c *Control) Stop() {
|
func (c *Control) Stop() {
|
||||||
c.stateLock.Lock()
|
|
||||||
if c.state != Started {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
// We are stopping or stopped already
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.state = Stopping
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
|
|
||||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||||
// being created while we're shutting them all down.
|
// being created while we're shutting them all down.
|
||||||
c.cancel()
|
c.cancel()
|
||||||
@@ -133,7 +85,7 @@ func (c *Control) Stop() {
|
|||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.WithError(err).Error("Close interface failed")
|
||||||
}
|
}
|
||||||
c.state = Stopped
|
c.l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
localIndexId: 201,
|
localIndexId: 201,
|
||||||
vpnAddrs: []netip.Addr{vpnIp},
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: nil,
|
relays: map[netip.Addr]struct{}{},
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
localIndexId: 201,
|
localIndexId: 201,
|
||||||
vpnAddrs: []netip.Addr{vpnIp2},
|
vpnAddrs: []netip.Addr{vpnIp2},
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: nil,
|
relays: map[netip.Addr]struct{}{},
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -506,7 +506,7 @@ func TestReestablishRelays(t *testing.T) {
|
|||||||
curIndexes := len(myControl.GetHostmap().Indexes)
|
curIndexes := len(myControl.GetHostmap().Indexes)
|
||||||
for curIndexes >= start {
|
for curIndexes >= start {
|
||||||
curIndexes = len(myControl.GetHostmap().Indexes)
|
curIndexes = len(myControl.GetHostmap().Indexes)
|
||||||
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
|
r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
|
||||||
|
|
||||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
@@ -1052,9 +1052,6 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
t.Log("Stand up a tunnel between me and them")
|
t.Log("Stand up a tunnel between me and them")
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
|
||||||
theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
|
|
||||||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
r.Log("Renew their certificate and spin until mine sees it")
|
r.Log("Renew their certificate and spin until mine sees it")
|
||||||
|
|||||||
@@ -700,7 +700,6 @@ func (r *R) FlushAll() {
|
|||||||
r.Unlock()
|
r.Unlock()
|
||||||
panic("Can't FlushAll for host: " + p.To.String())
|
panic("Can't FlushAll for host: " + p.To.String())
|
||||||
}
|
}
|
||||||
receiver.InjectUDPPacket(p)
|
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,57 +0,0 @@
|
|||||||
//go:build e2e_testing
|
|
||||||
// +build e2e_testing
|
|
||||||
|
|
||||||
package e2e
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
"github.com/slackhq/nebula/cert_test"
|
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDropInactiveTunnels(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
|
|
||||||
r.Log("Go inactive and wait for the tunnels to get dropped")
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
myIndexes := len(myControl.GetHostmap().Indexes)
|
|
||||||
theirIndexes := len(theirControl.GetHostmap().Indexes)
|
|
||||||
if myIndexes == 0 && theirIndexes == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
|
|
||||||
if since > time.Second*30 {
|
|
||||||
t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
@@ -132,13 +132,6 @@ listen:
|
|||||||
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
||||||
# default is 64, does not support reload
|
# default is 64, does not support reload
|
||||||
#batch: 64
|
#batch: 64
|
||||||
|
|
||||||
# Control batching between UDP and TUN pipelines
|
|
||||||
#batch:
|
|
||||||
# inbound_size: 32 # packets to queue from UDP before handing to workers
|
|
||||||
# outbound_size: 32 # packets to queue from TUN before handing to workers
|
|
||||||
# flush_interval: 50us # flush partially filled batches after this duration
|
|
||||||
# max_outstanding: 1028 # batches buffered per routine on each channel
|
|
||||||
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
|
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
|
||||||
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default)
|
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default)
|
||||||
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide
|
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide
|
||||||
@@ -345,18 +338,6 @@ logging:
|
|||||||
# after receiving the response for lighthouse queries
|
# after receiving the response for lighthouse queries
|
||||||
#trigger_buffer: 64
|
#trigger_buffer: 64
|
||||||
|
|
||||||
# Tunnel manager settings
|
|
||||||
#tunnels:
|
|
||||||
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
|
|
||||||
# elapsed.
|
|
||||||
# In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
|
|
||||||
# This setting is reloadable
|
|
||||||
#drop_inactive: false
|
|
||||||
|
|
||||||
# inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
|
|
||||||
# inactive and eligible to be dropped.
|
|
||||||
# This setting is reloadable
|
|
||||||
#inactivity_timeout: 10m
|
|
||||||
|
|
||||||
# Nebula security group configuration
|
# Nebula security group configuration
|
||||||
firewall:
|
firewall:
|
||||||
|
|||||||
197
firewall_test.go
197
firewall_test.go
@@ -68,9 +68,6 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ti6, err := netip.ParsePrefix("fd12::34/128")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
// An empty rule is any
|
// An empty rule is any
|
||||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||||
@@ -95,24 +92,12 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
|
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
|
||||||
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
|
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
|
||||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||||
@@ -132,13 +117,6 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
anyIp6, err := netip.ParsePrefix("::/0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
|
||||||
|
|
||||||
// Test error conditions
|
// Test error conditions
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -221,82 +199,6 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropV6(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
|
||||||
l.SetOutput(ob)
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
|
||||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
|
||||||
LocalPort: 10,
|
|
||||||
RemotePort: 90,
|
|
||||||
Protocol: firewall.ProtoUDP,
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
c := dummyCert{
|
|
||||||
name: "host1",
|
|
||||||
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
h := HostInfo{
|
|
||||||
ConnectionState: &ConnectionState{
|
|
||||||
peerCert: &cert.CachedCertificate{
|
|
||||||
Certificate: &c,
|
|
||||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
|
||||||
}
|
|
||||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
||||||
cp := cert.NewCAPool()
|
|
||||||
|
|
||||||
// Drop outbound
|
|
||||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
|
||||||
// Allow inbound
|
|
||||||
resetConntrack(fw)
|
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
|
||||||
// Allow outbound because conntrack
|
|
||||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
|
||||||
|
|
||||||
// test remote mismatch
|
|
||||||
oldRemote := p.RemoteAddr
|
|
||||||
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
|
||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
|
||||||
p.RemoteAddr = oldRemote
|
|
||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
f := &Firewall{}
|
f := &Firewall{}
|
||||||
ft := FirewallTable{
|
ft := FirewallTable{
|
||||||
@@ -306,10 +208,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
||||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
|
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
|
||||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
|
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
|
||||||
|
|
||||||
pfix6 := netip.MustParsePrefix("fd11::11/128")
|
|
||||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
|
|
||||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
|
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
b.Run("fail on proto", func(b *testing.B) {
|
b.Run("fail on proto", func(b *testing.B) {
|
||||||
@@ -341,15 +239,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
|
|
||||||
c := &cert.CachedCertificate{
|
|
||||||
Certificate: &dummyCert{},
|
|
||||||
}
|
|
||||||
ip := netip.MustParsePrefix("fd99::99/128")
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||||
c := &cert.CachedCertificate{
|
c := &cert.CachedCertificate{
|
||||||
@@ -363,18 +252,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
|
|
||||||
c := &cert.CachedCertificate{
|
|
||||||
Certificate: &dummyCert{
|
|
||||||
name: "nope",
|
|
||||||
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
|
|
||||||
},
|
|
||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
|
||||||
}
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||||
c := &cert.CachedCertificate{
|
c := &cert.CachedCertificate{
|
||||||
@@ -388,18 +265,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
|
|
||||||
c := &cert.CachedCertificate{
|
|
||||||
Certificate: &dummyCert{
|
|
||||||
name: "nope",
|
|
||||||
networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
|
|
||||||
},
|
|
||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
|
||||||
}
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("pass on group on any local cidr", func(b *testing.B) {
|
b.Run("pass on group on any local cidr", func(b *testing.B) {
|
||||||
c := &cert.CachedCertificate{
|
c := &cert.CachedCertificate{
|
||||||
@@ -424,17 +289,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
|
|
||||||
c := &cert.CachedCertificate{
|
|
||||||
Certificate: &dummyCert{
|
|
||||||
name: "nope",
|
|
||||||
},
|
|
||||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
|
||||||
}
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("pass on name", func(b *testing.B) {
|
b.Run("pass on name", func(b *testing.B) {
|
||||||
c := &cert.CachedCertificate{
|
c := &cert.CachedCertificate{
|
||||||
@@ -593,42 +447,6 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3V6(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
|
||||||
l.SetOutput(ob)
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
|
||||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
|
||||||
LocalPort: 1,
|
|
||||||
RemotePort: 1,
|
|
||||||
Protocol: firewall.ProtoUDP,
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
network := netip.MustParsePrefix("fd12::34/120")
|
|
||||||
c := cert.CachedCertificate{
|
|
||||||
Certificate: &dummyCert{
|
|
||||||
name: "host-owner",
|
|
||||||
networks: []netip.Prefix{network},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
h := HostInfo{
|
|
||||||
ConnectionState: &ConnectionState{
|
|
||||||
peerCert: &c,
|
|
||||||
},
|
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
|
||||||
}
|
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
|
||||||
|
|
||||||
// Test a remote address match
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
|
||||||
cp := cert.NewCAPool()
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
|
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
@@ -909,21 +727,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with cidr ipv6
|
|
||||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
|
|
||||||
|
|
||||||
// Test adding rule with local_cidr ipv6
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
|
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
|
|||||||
10
go.mod
10
go.mod
@@ -1,9 +1,11 @@
|
|||||||
module github.com/slackhq/nebula
|
module github.com/slackhq/nebula
|
||||||
|
|
||||||
go 1.25
|
go 1.23.0
|
||||||
|
|
||||||
|
toolchain go1.24.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
dario.cat/mergo v1.0.2
|
dario.cat/mergo v1.0.1
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
||||||
github.com/armon/go-radix v1.0.0
|
github.com/armon/go-radix v1.0.0
|
||||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||||
@@ -21,7 +23,7 @@ require (
|
|||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.0
|
||||||
golang.org/x/crypto v0.37.0
|
golang.org/x/crypto v0.37.0
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||||
golang.org/x/net v0.39.0
|
golang.org/x/net v0.39.0
|
||||||
@@ -46,7 +48,7 @@ require (
|
|||||||
github.com/prometheus/client_model v0.6.1 // indirect
|
github.com/prometheus/client_model v0.6.1 // indirect
|
||||||
github.com/prometheus/common v0.62.0 // indirect
|
github.com/prometheus/common v0.62.0 // indirect
|
||||||
github.com/prometheus/procfs v0.15.1 // indirect
|
github.com/prometheus/procfs v0.15.1 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.4 // indirect
|
||||||
golang.org/x/mod v0.23.0 // indirect
|
golang.org/x/mod v0.23.0 // indirect
|
||||||
golang.org/x/time v0.5.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.30.0 // indirect
|
golang.org/x/tools v0.30.0 // indirect
|
||||||
|
|||||||
12
go.sum
12
go.sum
@@ -1,6 +1,6 @@
|
|||||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
||||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||||
@@ -145,10 +145,10 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
|
|||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
||||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
|
|||||||
@@ -249,7 +249,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
lastHandshakeTime: hs.Details.Time,
|
lastHandshakeTime: hs.Details.Time,
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: nil,
|
relays: map[netip.Addr]struct{}{},
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
@@ -457,9 +457,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
||||||
|
|
||||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
hostinfo.remotes.ResetBlockedRemotes()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -652,7 +652,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
||||||
@@ -667,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
hostinfo.remotes.ResetBlockedRemotes()
|
||||||
f.metricHandshakes.Update(duration)
|
f.metricHandshakes.Update(duration)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -450,7 +450,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
|
|||||||
vpnAddrs: []netip.Addr{vpnAddr},
|
vpnAddrs: []netip.Addr{vpnAddr},
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: nil,
|
relays: map[netip.Addr]struct{}{},
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
|
|||||||
38
hostmap.go
38
hostmap.go
@@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -17,10 +16,12 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// const ProbeLen = 100
|
||||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||||
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
|
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
|
||||||
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
|
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
|
||||||
const MaxRemotes = 10
|
const MaxRemotes = 10
|
||||||
|
const maxRecvError = 4
|
||||||
|
|
||||||
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
|
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
|
||||||
// 5 allows for an initial handshake and each host pair re-handshaking twice
|
// 5 allows for an initial handshake and each host pair re-handshaking twice
|
||||||
@@ -67,7 +68,7 @@ type HostMap struct {
|
|||||||
type RelayState struct {
|
type RelayState struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
|
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
|
||||||
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
|
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
|
||||||
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
|
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
|
||||||
// the RelayState Lock held)
|
// the RelayState Lock held)
|
||||||
@@ -78,12 +79,7 @@ type RelayState struct {
|
|||||||
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
|
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
|
||||||
rs.Lock()
|
rs.Lock()
|
||||||
defer rs.Unlock()
|
defer rs.Unlock()
|
||||||
for idx, val := range rs.relays {
|
delete(rs.relays, ip)
|
||||||
if val == ip {
|
|
||||||
rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
|
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
|
||||||
@@ -128,16 +124,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
|
|||||||
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
|
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
|
||||||
rs.Lock()
|
rs.Lock()
|
||||||
defer rs.Unlock()
|
defer rs.Unlock()
|
||||||
if !slices.Contains(rs.relays, ip) {
|
rs.relays[ip] = struct{}{}
|
||||||
rs.relays = append(rs.relays, ip)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RelayState) CopyRelayIps() []netip.Addr {
|
func (rs *RelayState) CopyRelayIps() []netip.Addr {
|
||||||
ret := make([]netip.Addr, len(rs.relays))
|
|
||||||
rs.RLock()
|
rs.RLock()
|
||||||
defer rs.RUnlock()
|
defer rs.RUnlock()
|
||||||
copy(ret, rs.relays)
|
ret := make([]netip.Addr, 0, len(rs.relays))
|
||||||
|
for ip := range rs.relays {
|
||||||
|
ret = append(ret, ip)
|
||||||
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,6 +220,7 @@ type HostInfo struct {
|
|||||||
// The host may have other vpn addresses that are outside our
|
// The host may have other vpn addresses that are outside our
|
||||||
// vpn networks but were removed because they are not usable
|
// vpn networks but were removed because they are not usable
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
|
recvError atomic.Uint32
|
||||||
|
|
||||||
// networks are both all vpn and unsafe networks assigned to this host
|
// networks are both all vpn and unsafe networks assigned to this host
|
||||||
networks *bart.Lite
|
networks *bart.Lite
|
||||||
@@ -253,14 +250,6 @@ type HostInfo struct {
|
|||||||
// Used to track other hostinfos for this vpn ip since only 1 can be primary
|
// Used to track other hostinfos for this vpn ip since only 1 can be primary
|
||||||
// Synchronised via hostmap lock and not the hostinfo lock.
|
// Synchronised via hostmap lock and not the hostinfo lock.
|
||||||
next, prev *HostInfo
|
next, prev *HostInfo
|
||||||
|
|
||||||
//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
|
|
||||||
in, out, pendingDeletion atomic.Bool
|
|
||||||
|
|
||||||
// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
|
|
||||||
// This value will be behind against actual tunnel utilization in the hot path.
|
|
||||||
// This should only be used by the ConnectionManagers ticker routine.
|
|
||||||
lastUsed time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ViaSender struct {
|
type ViaSender struct {
|
||||||
@@ -730,6 +719,13 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *HostInfo) RecvErrorExceeded() bool {
|
||||||
|
if i.recvError.Add(1) >= maxRecvError {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||||
// Simple case, no CIDRTree needed
|
// Simple case, no CIDRTree needed
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHostMap_MakePrimary(t *testing.T) {
|
func TestHostMap_MakePrimary(t *testing.T) {
|
||||||
@@ -216,31 +215,3 @@ func TestHostMap_reload(t *testing.T) {
|
|||||||
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
||||||
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHostMap_RelayState(t *testing.T) {
|
|
||||||
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
|
|
||||||
a1 := netip.MustParseAddr("::1")
|
|
||||||
a2 := netip.MustParseAddr("2001::1")
|
|
||||||
|
|
||||||
h1.relayState.InsertRelayTo(a1)
|
|
||||||
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
|
||||||
h1.relayState.InsertRelayTo(a2)
|
|
||||||
assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
|
|
||||||
// Ensure that the first relay added is the first one returned in the copy
|
|
||||||
currentRelays := h1.relayState.CopyRelayIps()
|
|
||||||
require.Len(t, currentRelays, 2)
|
|
||||||
assert.Equal(t, a1, currentRelays[0])
|
|
||||||
|
|
||||||
// Deleting the last one in the list works ok
|
|
||||||
h1.relayState.DeleteRelay(a2)
|
|
||||||
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
|
||||||
|
|
||||||
// Deleting an element not in the list works ok
|
|
||||||
h1.relayState.DeleteRelay(a2)
|
|
||||||
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
|
||||||
|
|
||||||
// Deleting the only element in the list works ok
|
|
||||||
h1.relayState.DeleteRelay(a1)
|
|
||||||
assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|||||||
68
inside.go
68
inside.go
@@ -11,19 +11,19 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, queue func(netip.AddrPort, int), q int, localCache firewall.ConntrackCache) bool {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||||
}
|
}
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ignore local broadcast packets
|
// Ignore local broadcast packets
|
||||||
if f.dropLocalBroadcast {
|
if f.dropLocalBroadcast {
|
||||||
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,12 +40,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
}
|
}
|
||||||
// Otherwise, drop. On linux, we should never see these packets - Linux
|
// Otherwise, drop. On linux, we should never see these packets - Linux
|
||||||
// routes packets from the nebula addr to the nebula addr through the loopback device.
|
// routes packets from the nebula addr to the nebula addr through the loopback device.
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ignore multicast packets
|
// Ignore multicast packets
|
||||||
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||||
@@ -59,18 +59,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
WithField("fwPacket", fwPacket).
|
WithField("fwPacket", fwPacket).
|
||||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||||
}
|
}
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ready {
|
if !ready {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
if dropReason == nil {
|
if dropReason == nil {
|
||||||
return f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, queue, q)
|
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||||
}
|
|
||||||
|
|
||||||
|
} else {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).
|
hostinfo.logger(f.l).
|
||||||
@@ -78,7 +78,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
WithField("reason", dropReason).
|
WithField("reason", dropReason).
|
||||||
Debugln("dropping outbound packet")
|
Debugln("dropping outbound packet")
|
||||||
}
|
}
|
||||||
return false
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||||
@@ -117,7 +117,7 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, nil, q)
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
||||||
@@ -228,7 +228,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
|
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||||
@@ -258,12 +258,12 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
|
|||||||
|
|
||||||
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
|
||||||
f.messageMetrics.Tx(t, st, 1)
|
f.messageMetrics.Tx(t, st, 1)
|
||||||
_ = f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
|
f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
|
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
|
||||||
f.messageMetrics.Tx(t, st, 1)
|
f.messageMetrics.Tx(t, st, 1)
|
||||||
_ = f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, nil, 0)
|
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
|
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
|
||||||
@@ -288,7 +288,7 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
c := via.ConnectionState.messageCounter.Add(1)
|
c := via.ConnectionState.messageCounter.Add(1)
|
||||||
|
|
||||||
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
|
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
|
||||||
f.connectionManager.Out(via)
|
f.connectionManager.Out(via.localIndexId)
|
||||||
|
|
||||||
// Authenticate the header and payload, but do not encrypt for this message type.
|
// Authenticate the header and payload, but do not encrypt for this message type.
|
||||||
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
|
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
|
||||||
@@ -331,12 +331,9 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendNoMetrics encrypts and writes/queues an outbound packet. It returns true
|
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
|
||||||
// when the payload has been handed to a caller-provided queue (meaning the
|
|
||||||
// caller is responsible for flushing it later).
|
|
||||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, queue func(netip.AddrPort, int), q int) bool {
|
|
||||||
if ci.eKey == nil {
|
if ci.eKey == nil {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
||||||
fullOut := out
|
fullOut := out
|
||||||
@@ -359,7 +356,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
|
|
||||||
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
||||||
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
|
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
|
||||||
f.connectionManager.Out(hostinfo)
|
f.connectionManager.Out(hostinfo.localIndexId)
|
||||||
|
|
||||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
||||||
// all our addrs and enable a faster roaming.
|
// all our addrs and enable a faster roaming.
|
||||||
@@ -383,28 +380,22 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
WithField("udpAddr", remote).WithField("counter", c).
|
WithField("udpAddr", remote).WithField("counter", c).
|
||||||
WithField("attemptedCounter", c).
|
WithField("attemptedCounter", c).
|
||||||
Error("Failed to encrypt outgoing packet")
|
Error("Failed to encrypt outgoing packet")
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dest := remote
|
if remote.IsValid() {
|
||||||
if !dest.IsValid() {
|
err = f.writers[q].WriteTo(out, remote)
|
||||||
dest = hostinfo.remote
|
|
||||||
}
|
|
||||||
|
|
||||||
if dest.IsValid() {
|
|
||||||
if queue != nil {
|
|
||||||
queue(dest, len(out))
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
err = f.writers[q].WriteTo(out, dest)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).WithError(err).
|
||||||
WithField("udpAddr", dest).Error("Failed to write outgoing packet")
|
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||||
}
|
}
|
||||||
return false
|
} else if hostinfo.remote.IsValid() {
|
||||||
|
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).
|
||||||
|
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
// Try to send via a relay
|
// Try to send via a relay
|
||||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||||
@@ -416,6 +407,5 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
476
interface.go
476
interface.go
@@ -6,8 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -18,34 +18,22 @@ import (
|
|||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const mtu = 9001
|
||||||
mtu = 9001
|
|
||||||
|
|
||||||
inboundBatchSizeDefault = 128
|
|
||||||
outboundBatchSizeDefault = 64
|
|
||||||
batchFlushIntervalDefault = 12 * time.Microsecond
|
|
||||||
maxOutstandingBatchesDefault = 8
|
|
||||||
sendBatchSizeDefault = 64
|
|
||||||
maxPendingPacketsDefault = 32
|
|
||||||
maxPendingBytesDefault = 64 * 1024
|
|
||||||
maxSendBufPerRoutineDefault = 16
|
|
||||||
)
|
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
Outside udp.Conn
|
Outside udp.Conn
|
||||||
Inside overlay.Device
|
Inside overlay.Device
|
||||||
pki *PKI
|
pki *PKI
|
||||||
Cipher string
|
|
||||||
Firewall *Firewall
|
Firewall *Firewall
|
||||||
ServeDns bool
|
ServeDns bool
|
||||||
HandshakeManager *HandshakeManager
|
HandshakeManager *HandshakeManager
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
connectionManager *connectionManager
|
checkInterval time.Duration
|
||||||
|
pendingDeletionInterval time.Duration
|
||||||
DropLocalBroadcast bool
|
DropLocalBroadcast bool
|
||||||
DropMulticast bool
|
DropMulticast bool
|
||||||
routines int
|
routines int
|
||||||
@@ -59,20 +47,9 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
BatchConfig BatchConfig
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type BatchConfig struct {
|
|
||||||
InboundBatchSize int
|
|
||||||
OutboundBatchSize int
|
|
||||||
FlushInterval time.Duration
|
|
||||||
MaxOutstandingPerChan int
|
|
||||||
MaxPendingPackets int
|
|
||||||
MaxPendingBytes int
|
|
||||||
MaxSendBuffersPerChan int
|
|
||||||
}
|
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
outside udp.Conn
|
outside udp.Conn
|
||||||
@@ -110,165 +87,12 @@ type Interface struct {
|
|||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
inPool sync.Pool
|
|
||||||
inbound []chan *packetBatch
|
|
||||||
|
|
||||||
outPool sync.Pool
|
|
||||||
outbound []chan *outboundBatch
|
|
||||||
|
|
||||||
packetBatchPool sync.Pool
|
|
||||||
outboundBatchPool sync.Pool
|
|
||||||
|
|
||||||
sendPool sync.Pool
|
|
||||||
sendBufCache [][]*[]byte
|
|
||||||
sendBatchSize int
|
|
||||||
|
|
||||||
inboundBatchSize int
|
|
||||||
outboundBatchSize int
|
|
||||||
batchFlushInterval time.Duration
|
|
||||||
maxOutstandingPerChan int
|
|
||||||
maxPendingPackets int
|
|
||||||
maxPendingBytes int
|
|
||||||
maxSendBufPerRoutine int
|
|
||||||
}
|
|
||||||
|
|
||||||
type outboundSend struct {
|
|
||||||
buf *[]byte
|
|
||||||
length int
|
|
||||||
addr netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
type packetBatch struct {
|
|
||||||
packets []*packet.Packet
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPacketBatch(capacity int) *packetBatch {
|
|
||||||
return &packetBatch{
|
|
||||||
packets: make([]*packet.Packet, 0, capacity),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *packetBatch) add(p *packet.Packet) {
|
|
||||||
b.packets = append(b.packets, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *packetBatch) reset() {
|
|
||||||
for i := range b.packets {
|
|
||||||
b.packets[i] = nil
|
|
||||||
}
|
|
||||||
b.packets = b.packets[:0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) getPacketBatch() *packetBatch {
|
|
||||||
if v := f.packetBatchPool.Get(); v != nil {
|
|
||||||
b := v.(*packetBatch)
|
|
||||||
b.reset()
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return newPacketBatch(f.inboundBatchSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) releasePacketBatch(b *packetBatch) {
|
|
||||||
b.reset()
|
|
||||||
f.packetBatchPool.Put(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
type outboundBatch struct {
|
|
||||||
payloads []*[]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func newOutboundBatch(capacity int) *outboundBatch {
|
|
||||||
return &outboundBatch{payloads: make([]*[]byte, 0, capacity)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *outboundBatch) add(buf *[]byte) {
|
|
||||||
b.payloads = append(b.payloads, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *outboundBatch) reset() {
|
|
||||||
for i := range b.payloads {
|
|
||||||
b.payloads[i] = nil
|
|
||||||
}
|
|
||||||
b.payloads = b.payloads[:0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) getOutboundBatch() *outboundBatch {
|
|
||||||
if v := f.outboundBatchPool.Get(); v != nil {
|
|
||||||
b := v.(*outboundBatch)
|
|
||||||
b.reset()
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return newOutboundBatch(f.outboundBatchSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
|
|
||||||
b.reset()
|
|
||||||
f.outboundBatchPool.Put(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) getSendBuffer(q int) *[]byte {
|
|
||||||
cache := f.sendBufCache[q]
|
|
||||||
if n := len(cache); n > 0 {
|
|
||||||
buf := cache[n-1]
|
|
||||||
f.sendBufCache[q] = cache[:n-1]
|
|
||||||
*buf = (*buf)[:0]
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
if v := f.sendPool.Get(); v != nil {
|
|
||||||
buf := v.(*[]byte)
|
|
||||||
*buf = (*buf)[:0]
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
b := make([]byte, mtu)
|
|
||||||
return &b
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) releaseSendBuffer(q int, buf *[]byte) {
|
|
||||||
if buf == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*buf = (*buf)[:0]
|
|
||||||
cache := f.sendBufCache[q]
|
|
||||||
if len(cache) < f.maxSendBufPerRoutine {
|
|
||||||
f.sendBufCache[q] = append(cache, buf)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
f.sendPool.Put(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) flushSendQueue(q int, pending *[]outboundSend, pendingBytes *int) {
|
|
||||||
if len(*pending) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
batch := make([]udp.BatchPacket, len(*pending))
|
|
||||||
for i, entry := range *pending {
|
|
||||||
batch[i] = udp.BatchPacket{
|
|
||||||
Payload: (*entry.buf)[:entry.length],
|
|
||||||
Addr: entry.addr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sent, err := f.writers[q].WriteBatch(batch)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("sent", sent).Error("Failed to batch send packets")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, entry := range *pending {
|
|
||||||
f.releaseSendBuffer(q, entry.buf)
|
|
||||||
}
|
|
||||||
*pending = (*pending)[:0]
|
|
||||||
if pendingBytes != nil {
|
|
||||||
*pendingBytes = 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
@@ -333,34 +157,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
if c.Firewall == nil {
|
if c.Firewall == nil {
|
||||||
return nil, errors.New("no firewall rules")
|
return nil, errors.New("no firewall rules")
|
||||||
}
|
}
|
||||||
if c.connectionManager == nil {
|
|
||||||
return nil, errors.New("no connection manager")
|
|
||||||
}
|
|
||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
|
|
||||||
bc := c.BatchConfig
|
|
||||||
if bc.InboundBatchSize <= 0 {
|
|
||||||
bc.InboundBatchSize = inboundBatchSizeDefault
|
|
||||||
}
|
|
||||||
if bc.OutboundBatchSize <= 0 {
|
|
||||||
bc.OutboundBatchSize = outboundBatchSizeDefault
|
|
||||||
}
|
|
||||||
if bc.FlushInterval <= 0 {
|
|
||||||
bc.FlushInterval = batchFlushIntervalDefault
|
|
||||||
}
|
|
||||||
if bc.MaxOutstandingPerChan <= 0 {
|
|
||||||
bc.MaxOutstandingPerChan = maxOutstandingBatchesDefault
|
|
||||||
}
|
|
||||||
if bc.MaxPendingPackets <= 0 {
|
|
||||||
bc.MaxPendingPackets = maxPendingPacketsDefault
|
|
||||||
}
|
|
||||||
if bc.MaxPendingBytes <= 0 {
|
|
||||||
bc.MaxPendingBytes = maxPendingBytesDefault
|
|
||||||
}
|
|
||||||
if bc.MaxSendBuffersPerChan <= 0 {
|
|
||||||
bc.MaxSendBuffersPerChan = maxSendBufPerRoutineDefault
|
|
||||||
}
|
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
@@ -383,7 +181,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||||
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
|
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
|
||||||
relayManager: c.relayManager,
|
relayManager: c.relayManager,
|
||||||
connectionManager: c.connectionManager,
|
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
|
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
@@ -393,54 +191,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||||
},
|
},
|
||||||
|
|
||||||
inbound: make([]chan *packetBatch, c.routines),
|
|
||||||
outbound: make([]chan *outboundBatch, c.routines),
|
|
||||||
|
|
||||||
l: c.l,
|
l: c.l,
|
||||||
|
|
||||||
inboundBatchSize: bc.InboundBatchSize,
|
|
||||||
outboundBatchSize: bc.OutboundBatchSize,
|
|
||||||
batchFlushInterval: bc.FlushInterval,
|
|
||||||
maxOutstandingPerChan: bc.MaxOutstandingPerChan,
|
|
||||||
maxPendingPackets: bc.MaxPendingPackets,
|
|
||||||
maxPendingBytes: bc.MaxPendingBytes,
|
|
||||||
maxSendBufPerRoutine: bc.MaxSendBuffersPerChan,
|
|
||||||
sendBatchSize: bc.OutboundBatchSize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < c.routines; i++ {
|
|
||||||
ifce.inbound[i] = make(chan *packetBatch, ifce.maxOutstandingPerChan)
|
|
||||||
ifce.outbound[i] = make(chan *outboundBatch, ifce.maxOutstandingPerChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifce.inPool = sync.Pool{New: func() any {
|
|
||||||
return packet.New()
|
|
||||||
}}
|
|
||||||
|
|
||||||
ifce.outPool = sync.Pool{New: func() any {
|
|
||||||
t := make([]byte, mtu)
|
|
||||||
return &t
|
|
||||||
}}
|
|
||||||
|
|
||||||
ifce.packetBatchPool = sync.Pool{New: func() any {
|
|
||||||
return newPacketBatch(ifce.inboundBatchSize)
|
|
||||||
}}
|
|
||||||
|
|
||||||
ifce.outboundBatchPool = sync.Pool{New: func() any {
|
|
||||||
return newOutboundBatch(ifce.outboundBatchSize)
|
|
||||||
}}
|
|
||||||
|
|
||||||
ifce.sendPool = sync.Pool{New: func() any {
|
|
||||||
buf := make([]byte, mtu)
|
|
||||||
return &buf
|
|
||||||
}}
|
|
||||||
ifce.sendBufCache = make([][]*[]byte, c.routines)
|
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||||
|
|
||||||
ifce.connectionManager.intf = ifce
|
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
|
||||||
|
|
||||||
return ifce, nil
|
return ifce, nil
|
||||||
}
|
}
|
||||||
@@ -448,7 +206,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
// activate creates the interface on the host. After the interface is created, any
|
// activate creates the interface on the host. After the interface is created, any
|
||||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||||
// the interface isn't going to process anything until run() is called.
|
// the interface isn't going to process anything until run() is called.
|
||||||
func (f *Interface) activate() error {
|
func (f *Interface) activate() {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
@@ -469,44 +227,33 @@ func (f *Interface) activate() error {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
f.l.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = f.inside.Activate(); err != nil {
|
if err := f.inside.Activate(); err != nil {
|
||||||
f.inside.Close()
|
f.inside.Close()
|
||||||
return err
|
f.l.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
func (f *Interface) run() {
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) run(c context.Context) (func(), error) {
|
|
||||||
for i := 0; i < f.routines; i++ {
|
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
f.wg.Add(1)
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
|
||||||
f.wg.Add(1)
|
|
||||||
go f.listenIn(f.readers[i], i)
|
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
|
||||||
f.wg.Add(1)
|
|
||||||
go f.workerIn(i, c)
|
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
|
||||||
f.wg.Add(1)
|
|
||||||
go f.workerOut(i, c)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.wg.Wait, nil
|
// Launch n queues to read packets from tun dev
|
||||||
|
for i := 0; i < f.routines; i++ {
|
||||||
|
go f.listenIn(f.readers[i], i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
@@ -514,176 +261,41 @@ func (f *Interface) listenOut(i int) {
|
|||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
batch := f.getPacketBatch()
|
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
lastFlush := time.Now()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
|
plaintext := make([]byte, udp.MTU)
|
||||||
|
h := &header.H{}
|
||||||
|
fwPacket := &firewall.Packet{}
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
flush := func(force bool) {
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
if len(batch.packets) == 0 {
|
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
if force {
|
|
||||||
f.releasePacketBatch(batch)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.inbound[i] <- batch
|
|
||||||
batch = f.getPacketBatch()
|
|
||||||
lastFlush = time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
|
||||||
p := f.inPool.Get().(*packet.Packet)
|
|
||||||
p.Payload = p.Payload[:mtu]
|
|
||||||
copy(p.Payload, payload)
|
|
||||||
p.Payload = p.Payload[:len(payload)]
|
|
||||||
p.Addr = fromUdpAddr
|
|
||||||
batch.add(p)
|
|
||||||
|
|
||||||
if len(batch.packets) >= f.inboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
|
|
||||||
flush(false)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if len(batch.packets) > 0 {
|
|
||||||
f.inbound[i] <- batch
|
|
||||||
} else {
|
|
||||||
f.releasePacketBatch(batch)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && !f.closed.Load() {
|
|
||||||
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
|
|
||||||
//TODO: Trigger Control to close
|
|
||||||
}
|
|
||||||
|
|
||||||
f.l.Debugf("underlay reader %v is done", i)
|
|
||||||
f.wg.Done()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
batch := f.getOutboundBatch()
|
packet := make([]byte, mtu)
|
||||||
lastFlush := time.Now()
|
out := make([]byte, mtu)
|
||||||
|
fwPacket := &firewall.Packet{}
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
flush := func(force bool) {
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
if len(batch.payloads) == 0 {
|
|
||||||
if force {
|
|
||||||
f.releaseOutboundBatch(batch)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.outbound[i] <- batch
|
|
||||||
batch = f.getOutboundBatch()
|
|
||||||
lastFlush = time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
p := f.outPool.Get().(*[]byte)
|
n, err := reader.Read(packet)
|
||||||
*p = (*p)[:mtu]
|
|
||||||
n, err := reader.Read(*p)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !f.closed.Load() {
|
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet, closing")
|
|
||||||
//TODO: Trigger Control to close
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
*p = (*p)[:n]
|
|
||||||
batch.add(p)
|
|
||||||
|
|
||||||
if len(batch.payloads) >= f.outboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
|
|
||||||
flush(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(batch.payloads) > 0 {
|
|
||||||
f.outbound[i] <- batch
|
|
||||||
} else {
|
|
||||||
f.releaseOutboundBatch(batch)
|
|
||||||
}
|
|
||||||
|
|
||||||
f.l.Debugf("overlay reader %v is done", i)
|
|
||||||
f.wg.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) workerIn(i int, ctx context.Context) {
|
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
|
||||||
fwPacket2 := &firewall.Packet{}
|
|
||||||
nb2 := make([]byte, 12, 12)
|
|
||||||
result2 := make([]byte, mtu)
|
|
||||||
h := &header.H{}
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case batch := <-f.inbound[i]:
|
|
||||||
for _, p := range batch.packets {
|
|
||||||
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
|
||||||
p.Payload = p.Payload[:mtu]
|
|
||||||
f.inPool.Put(p)
|
|
||||||
}
|
|
||||||
f.releasePacketBatch(batch)
|
|
||||||
case <-ctx.Done():
|
|
||||||
f.wg.Done()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||||
|
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||||
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) workerOut(i int, ctx context.Context) {
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
|
||||||
fwPacket1 := &firewall.Packet{}
|
|
||||||
nb1 := make([]byte, 12, 12)
|
|
||||||
pending := make([]outboundSend, 0, f.sendBatchSize)
|
|
||||||
pendingBytes := 0
|
|
||||||
maxPendingPackets := f.maxPendingPackets
|
|
||||||
if maxPendingPackets <= 0 {
|
|
||||||
maxPendingPackets = f.sendBatchSize
|
|
||||||
}
|
|
||||||
maxPendingBytes := f.maxPendingBytes
|
|
||||||
if maxPendingBytes <= 0 {
|
|
||||||
maxPendingBytes = f.sendBatchSize * mtu
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case batch := <-f.outbound[i]:
|
|
||||||
for _, data := range batch.payloads {
|
|
||||||
sendBuf := f.getSendBuffer(i)
|
|
||||||
buf := (*sendBuf)[:0]
|
|
||||||
queue := func(addr netip.AddrPort, length int) {
|
|
||||||
if len(pending) >= maxPendingPackets || pendingBytes+length > maxPendingBytes {
|
|
||||||
f.flushSendQueue(i, &pending, &pendingBytes)
|
|
||||||
}
|
|
||||||
pending = append(pending, outboundSend{
|
|
||||||
buf: sendBuf,
|
|
||||||
length: length,
|
|
||||||
addr: addr,
|
|
||||||
})
|
|
||||||
pendingBytes += length
|
|
||||||
if len(pending) >= f.sendBatchSize || pendingBytes >= maxPendingBytes {
|
|
||||||
f.flushSendQueue(i, &pending, &pendingBytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sent := f.consumeInsidePacket(*data, fwPacket1, nb1, buf, queue, i, conntrackCache.Get(f.l))
|
|
||||||
if !sent {
|
|
||||||
f.releaseSendBuffer(i, sendBuf)
|
|
||||||
}
|
|
||||||
*data = (*data)[:mtu]
|
|
||||||
f.outPool.Put(data)
|
|
||||||
}
|
|
||||||
f.releaseOutboundBatch(batch)
|
|
||||||
if len(pending) > 0 {
|
|
||||||
f.flushSendQueue(i, &pending, &pendingBytes)
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
if len(pending) > 0 {
|
|
||||||
f.flushSendQueue(i, &pending, &pendingBytes)
|
|
||||||
}
|
|
||||||
f.wg.Done()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -836,7 +448,6 @@ func (f *Interface) GetCertState() *CertState {
|
|||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
// Release the udp readers
|
|
||||||
for _, u := range f.writers {
|
for _, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -844,13 +455,6 @@ func (f *Interface) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the tun readers
|
// Release the tun device
|
||||||
for _, u := range f.readers {
|
return f.inside.Close()
|
||||||
err := u.Close()
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Error("Error while closing tun device")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
229
lighthouse.go
229
lighthouse.go
@@ -24,7 +24,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var ErrHostNotKnown = errors.New("host not known")
|
var ErrHostNotKnown = errors.New("host not known")
|
||||||
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
|
|
||||||
|
|
||||||
type LightHouse struct {
|
type LightHouse struct {
|
||||||
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
||||||
@@ -57,7 +56,7 @@ type LightHouse struct {
|
|||||||
// staticList exists to avoid having a bool in each addrMap entry
|
// staticList exists to avoid having a bool in each addrMap entry
|
||||||
// since static should be rare
|
// since static should be rare
|
||||||
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
||||||
lighthouses atomic.Pointer[[]netip.Addr]
|
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
|
||||||
|
|
||||||
interval atomic.Int64
|
interval atomic.Int64
|
||||||
updateCancel context.CancelFunc
|
updateCancel context.CancelFunc
|
||||||
@@ -108,7 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
lighthouses := make([]netip.Addr, 0)
|
lighthouses := make(map[netip.Addr]struct{})
|
||||||
h.lighthouses.Store(&lighthouses)
|
h.lighthouses.Store(&lighthouses)
|
||||||
staticList := make(map[netip.Addr]struct{})
|
staticList := make(map[netip.Addr]struct{})
|
||||||
h.staticList.Store(&staticList)
|
h.staticList.Store(&staticList)
|
||||||
@@ -144,7 +143,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
|
|||||||
return *lh.staticList.Load()
|
return *lh.staticList.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) GetLighthouses() []netip.Addr {
|
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
|
||||||
return *lh.lighthouses.Load()
|
return *lh.lighthouses.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,12 +306,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("lighthouse.hosts") {
|
if initial || c.HasChanged("lighthouse.hosts") {
|
||||||
lhList, err := lh.parseLighthouses(c)
|
lhMap := make(map[netip.Addr]struct{})
|
||||||
|
err := lh.parseLighthouses(c, lhMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
lh.lighthouses.Store(&lhList)
|
lh.lighthouses.Store(&lhMap)
|
||||||
if !initial {
|
if !initial {
|
||||||
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
||||||
lh.l.Info("lighthouse.hosts has changed")
|
lh.l.Info("lighthouse.hosts has changed")
|
||||||
@@ -346,37 +346,36 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
|
||||||
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
|
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
|
||||||
if lh.amLighthouse && len(lhs) != 0 {
|
if lh.amLighthouse && len(lhs) != 0 {
|
||||||
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
||||||
}
|
}
|
||||||
out := make([]netip.Addr, len(lhs))
|
|
||||||
|
|
||||||
for i, host := range lhs {
|
for i, host := range lhs {
|
||||||
addr, err := netip.ParseAddr(host)
|
addr, err := netip.ParseAddr(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||||
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||||
}
|
}
|
||||||
out[i] = addr
|
lhMap[addr] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !lh.amLighthouse && len(out) == 0 {
|
if !lh.amLighthouse && len(lhMap) == 0 {
|
||||||
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
|
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
|
||||||
}
|
}
|
||||||
|
|
||||||
staticList := lh.GetStaticHostList()
|
staticList := lh.GetStaticHostList()
|
||||||
for i := range out {
|
for lhAddr, _ := range lhMap {
|
||||||
if _, ok := staticList[out[i]]; !ok {
|
if _, ok := staticList[lhAddr]; !ok {
|
||||||
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
|
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStaticMapCadence(c *config.C) (time.Duration, error) {
|
func getStaticMapCadence(c *config.C) (time.Duration, error) {
|
||||||
@@ -487,7 +486,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
|
|||||||
lh.Lock()
|
lh.Lock()
|
||||||
defer lh.Unlock()
|
defer lh.Unlock()
|
||||||
// Add an entry if we don't already have one
|
// Add an entry if we don't already have one
|
||||||
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
|
return lh.unlockedGetRemoteList(vpnAddrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
||||||
@@ -520,15 +519,11 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
||||||
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
|
// First we check the static mapping
|
||||||
staticList := lh.GetStaticHostList()
|
// and do nothing if it is there
|
||||||
for _, addr := range allVpnAddrs {
|
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
|
||||||
if _, ok := staticList[addr]; ok {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// None of the VpnAddrs were present. Now we can do the deletes.
|
|
||||||
lh.Lock()
|
lh.Lock()
|
||||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -570,7 +565,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
|
|||||||
am.unlockedSetHostnamesResults(hr)
|
am.unlockedSetHostnamesResults(hr)
|
||||||
|
|
||||||
for _, addrPort := range hr.GetAddrs() {
|
for _, addrPort := range hr.GetAddrs() {
|
||||||
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
|
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
@@ -632,30 +627,23 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
|
|||||||
return len(calculatedV4) > 0 || len(calculatedV6) > 0
|
return len(calculatedV4) > 0 || len(calculatedV6) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedGetRemoteList assumes you have the lh lock
|
// unlockedGetRemoteList
|
||||||
|
// assumes you have the lh lock
|
||||||
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
||||||
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
|
am, ok := lh.addrMap[allAddrs[0]]
|
||||||
for i, addr := range allAddrs {
|
if !ok {
|
||||||
am, ok := lh.addrMap[addr]
|
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
|
||||||
if ok {
|
|
||||||
if i != 0 {
|
|
||||||
lh.addrMap[allAddrs[0]] = am
|
|
||||||
}
|
|
||||||
return am
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
am := NewRemoteList(allAddrs, lh.shouldAdd)
|
|
||||||
for _, addr := range allAddrs {
|
for _, addr := range allAddrs {
|
||||||
lh.addrMap[addr] = am
|
lh.addrMap[addr] = am
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return am
|
return am
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
|
||||||
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Level >= logrus.TraceLevel {
|
||||||
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
|
||||||
Trace("remoteAllowList.Allow")
|
Trace("remoteAllowList.Allow")
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -710,24 +698,21 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
||||||
l := lh.GetLighthouses()
|
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
|
||||||
for i := range l {
|
|
||||||
if l[i] == vpnAddr {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
|
||||||
|
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
|
||||||
|
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
|
||||||
l := lh.GetLighthouses()
|
l := lh.GetLighthouses()
|
||||||
for i := range vpnAddrs {
|
for _, a := range vpnAddr {
|
||||||
for j := range l {
|
if _, ok := l[a]; ok {
|
||||||
if l[j] == vpnAddrs[i] {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -767,7 +752,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
queried := 0
|
queried := 0
|
||||||
lighthouses := lh.GetLighthouses()
|
lighthouses := lh.GetLighthouses()
|
||||||
|
|
||||||
for _, lhVpnAddr := range lighthouses {
|
for lhVpnAddr := range lighthouses {
|
||||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||||
if hi != nil {
|
if hi != nil {
|
||||||
v = hi.ConnectionState.myCert.Version()
|
v = hi.ConnectionState.myCert.Version()
|
||||||
@@ -885,7 +870,7 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
updated := 0
|
updated := 0
|
||||||
lighthouses := lh.GetLighthouses()
|
lighthouses := lh.GetLighthouses()
|
||||||
|
|
||||||
for _, lhVpnAddr := range lighthouses {
|
for lhVpnAddr := range lighthouses {
|
||||||
var v cert.Version
|
var v cert.Version
|
||||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||||
if hi != nil {
|
if hi != nil {
|
||||||
@@ -943,6 +928,7 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
V4AddrPorts: v4,
|
V4AddrPorts: v4,
|
||||||
V6AddrPorts: v6,
|
V6AddrPorts: v6,
|
||||||
RelayVpnAddrs: relays,
|
RelayVpnAddrs: relays,
|
||||||
|
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1062,19 +1048,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
useVersion := cert.Version1
|
||||||
if err != nil {
|
var queryVpnAddr netip.Addr
|
||||||
|
if n.Details.OldVpnAddr != 0 {
|
||||||
|
b := [4]byte{}
|
||||||
|
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||||
|
queryVpnAddr = netip.AddrFrom4(b)
|
||||||
|
useVersion = 1
|
||||||
|
} else if n.Details.VpnAddr != nil {
|
||||||
|
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||||
|
useVersion = 2
|
||||||
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
|
||||||
Debugln("Dropping malformed HostQuery")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
|
||||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
|
||||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1083,6 +1069,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostQueryReply
|
n.Type = NebulaMeta_HostQueryReply
|
||||||
if useVersion == cert.Version1 {
|
if useVersion == cert.Version1 {
|
||||||
|
if !queryVpnAddr.Is4() {
|
||||||
|
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
|
||||||
|
}
|
||||||
b := queryVpnAddr.As4()
|
b := queryVpnAddr.As4()
|
||||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
||||||
} else {
|
} else {
|
||||||
@@ -1127,9 +1116,8 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
if ok {
|
if ok {
|
||||||
whereToPunch = newDest
|
whereToPunch = newDest
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
|
||||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
//choosing to do nothing for now, but maybe we return an error?
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1188,17 +1176,19 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
|||||||
if !r.Is4() {
|
if !r.Is4() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
b = r.As4()
|
b = r.As4()
|
||||||
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
|
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if v == cert.Version2 {
|
} else if v == cert.Version2 {
|
||||||
for _, r := range c.relay.relay {
|
for _, r := range c.relay.relay {
|
||||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
//TODO: CERT-V2 don't panic
|
||||||
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
panic("unsupported version")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1208,16 +1198,18 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
lhh.lh.Lock()
|
||||||
if err != nil {
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
var certVpnAddr netip.Addr
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
if n.Details.OldVpnAddr != 0 {
|
||||||
}
|
b := [4]byte{}
|
||||||
return
|
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||||
|
certVpnAddr = netip.AddrFrom4(b)
|
||||||
|
} else if n.Details.VpnAddr != nil {
|
||||||
|
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||||
}
|
}
|
||||||
relays := n.Details.GetRelays()
|
relays := n.Details.GetRelays()
|
||||||
|
|
||||||
lhh.lh.Lock()
|
|
||||||
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
|
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
|
||||||
am.Lock()
|
am.Lock()
|
||||||
lhh.lh.Unlock()
|
lhh.lh.Unlock()
|
||||||
@@ -1242,24 +1234,27 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
|
|
||||||
var detailsVpnAddr netip.Addr
|
var detailsVpnAddr netip.Addr
|
||||||
var useVersion cert.Version
|
useVersion := cert.Version1
|
||||||
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
|
if n.Details.OldVpnAddr != 0 {
|
||||||
b := [4]byte{}
|
b := [4]byte{}
|
||||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||||
detailsVpnAddr = netip.AddrFrom4(b)
|
detailsVpnAddr = netip.AddrFrom4(b)
|
||||||
useVersion = cert.Version1
|
useVersion = cert.Version1
|
||||||
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
|
} else if n.Details.VpnAddr != nil {
|
||||||
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||||
useVersion = cert.Version2
|
useVersion = cert.Version2
|
||||||
} else {
|
} else {
|
||||||
detailsVpnAddr = netip.Addr{}
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
useVersion = cert.Version2
|
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
|
||||||
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
|
||||||
|
//Simple check that the host sent this not someone else
|
||||||
|
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
||||||
}
|
}
|
||||||
@@ -1273,24 +1268,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
am.Lock()
|
am.Lock()
|
||||||
lhh.lh.Unlock()
|
lhh.lh.Unlock()
|
||||||
|
|
||||||
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
||||||
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
||||||
am.unlockedSetRelay(fromVpnAddrs[0], relays)
|
am.unlockedSetRelay(fromVpnAddrs[0], relays)
|
||||||
am.Unlock()
|
am.Unlock()
|
||||||
|
|
||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostUpdateNotificationAck
|
n.Type = NebulaMeta_HostUpdateNotificationAck
|
||||||
switch useVersion {
|
|
||||||
case cert.Version1:
|
if useVersion == cert.Version1 {
|
||||||
if !fromVpnAddrs[0].Is4() {
|
if !fromVpnAddrs[0].Is4() {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrB := fromVpnAddrs[0].As4()
|
vpnAddrB := fromVpnAddrs[0].As4()
|
||||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
|
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
|
||||||
case cert.Version2:
|
} else if useVersion == cert.Version2 {
|
||||||
// do nothing, we want to send a blank message
|
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
|
||||||
default:
|
} else {
|
||||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1308,20 +1303,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||||
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
|
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
|
||||||
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
|
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
|
||||||
|
//maybe one day we'll have a better idea, if it matters.
|
||||||
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
|
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
|
||||||
if err != nil {
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
|
||||||
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
empty := []byte{0}
|
empty := []byte{0}
|
||||||
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
|
punch := func(vpnPeer netip.AddrPort) {
|
||||||
if !vpnPeer.IsValid() {
|
if !vpnPeer.IsValid() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1333,31 +1321,48 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
|
var logVpnAddr netip.Addr
|
||||||
|
if n.Details.OldVpnAddr != 0 {
|
||||||
|
b := [4]byte{}
|
||||||
|
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||||
|
logVpnAddr = netip.AddrFrom4(b)
|
||||||
|
} else if n.Details.VpnAddr != nil {
|
||||||
|
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||||
|
}
|
||||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range n.Details.V4AddrPorts {
|
for _, a := range n.Details.V4AddrPorts {
|
||||||
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
punch(protoV4AddrPortToNetAddrPort(a))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range n.Details.V6AddrPorts {
|
for _, a := range n.Details.V6AddrPorts {
|
||||||
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
punch(protoV6AddrPortToNetAddrPort(a))
|
||||||
}
|
}
|
||||||
|
|
||||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||||
// of a double nat or other difficult scenario, this may help establish
|
// of a double nat or other difficult scenario, this may help establish
|
||||||
// a tunnel.
|
// a tunnel.
|
||||||
if lhh.lh.punchy.GetRespond() {
|
if lhh.lh.punchy.GetRespond() {
|
||||||
|
var queryVpnAddr netip.Addr
|
||||||
|
if n.Details.OldVpnAddr != 0 {
|
||||||
|
b := [4]byte{}
|
||||||
|
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||||
|
queryVpnAddr = netip.AddrFrom4(b)
|
||||||
|
} else if n.Details.VpnAddr != nil {
|
||||||
|
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
|
||||||
}
|
}
|
||||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||||
// managed by a channel.
|
// managed by a channel.
|
||||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1436,17 +1441,3 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
|
|||||||
}
|
}
|
||||||
return netip.Addr{}, false
|
return netip.Addr{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
|
|
||||||
if d.OldVpnAddr != 0 {
|
|
||||||
b := [4]byte{}
|
|
||||||
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
|
|
||||||
detailsVpnAddr := netip.AddrFrom4(b)
|
|
||||||
return detailsVpnAddr, cert.Version1, nil
|
|
||||||
} else if d.VpnAddr != nil {
|
|
||||||
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
|
|
||||||
return detailsVpnAddr, cert.Version2, nil
|
|
||||||
} else {
|
|
||||||
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -493,123 +493,3 @@ func Test_findNetworkUnion(t *testing.T) {
|
|||||||
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
|
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
|
|
||||||
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
|
||||||
|
|
||||||
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
|
|
||||||
testStaticHost := netip.MustParseAddr("10.128.0.42")
|
|
||||||
//myVpnIp := netip.MustParseAddr("10.128.0.2")
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
|
||||||
lh1 := "10.128.0.2"
|
|
||||||
c.Settings["lighthouse"] = map[string]any{
|
|
||||||
"hosts": []any{lh1},
|
|
||||||
"interval": "1s",
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
|
||||||
c.Settings["static_host_map"] = map[string]any{
|
|
||||||
lh1: []any{"1.1.1.1:4242"},
|
|
||||||
"10.128.0.42": []any{"1.2.3.4:4242"},
|
|
||||||
}
|
|
||||||
|
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
|
||||||
nt := new(bart.Lite)
|
|
||||||
nt.Insert(myVpnNet)
|
|
||||||
cs := &CertState{
|
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
|
||||||
myVpnNetworksTable: nt,
|
|
||||||
}
|
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
lh.ifce = &mockEncWriter{}
|
|
||||||
|
|
||||||
//test that we actually have the static entry:
|
|
||||||
out := lh.Query(testStaticHost)
|
|
||||||
assert.NotNil(t, out)
|
|
||||||
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
|
|
||||||
out.Rebuild([]netip.Prefix{}) //why tho
|
|
||||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
|
||||||
|
|
||||||
//bolt on a lower numbered primary IP
|
|
||||||
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
|
|
||||||
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
|
|
||||||
lh.addrMap[testSameHostNotStatic] = am
|
|
||||||
out.Rebuild([]netip.Prefix{}) //???
|
|
||||||
|
|
||||||
//test that we actually have the static entry:
|
|
||||||
out = lh.Query(testStaticHost)
|
|
||||||
assert.NotNil(t, out)
|
|
||||||
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
|
||||||
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
|
||||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
|
||||||
|
|
||||||
//test that we actually have the static entry for BOTH:
|
|
||||||
out2 := lh.Query(testSameHostNotStatic)
|
|
||||||
assert.Same(t, out2, out)
|
|
||||||
|
|
||||||
//now do the delete
|
|
||||||
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
|
|
||||||
//verify
|
|
||||||
out = lh.Query(testSameHostNotStatic)
|
|
||||||
assert.NotNil(t, out)
|
|
||||||
if out == nil {
|
|
||||||
t.Fatal("expected non-nil query for the static host")
|
|
||||||
}
|
|
||||||
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
|
||||||
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
|
||||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLighthouse_DeletesWork(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
|
|
||||||
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
|
||||||
testHost := netip.MustParseAddr("10.128.0.42")
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
|
||||||
lh1 := "10.128.0.2"
|
|
||||||
c.Settings["lighthouse"] = map[string]any{
|
|
||||||
"hosts": []any{lh1},
|
|
||||||
"interval": "1s",
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
|
||||||
c.Settings["static_host_map"] = map[string]any{
|
|
||||||
lh1: []any{"1.1.1.1:4242"},
|
|
||||||
}
|
|
||||||
|
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
|
||||||
nt := new(bart.Lite)
|
|
||||||
nt.Insert(myVpnNet)
|
|
||||||
cs := &CertState{
|
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
|
||||||
myVpnNetworksTable: nt,
|
|
||||||
}
|
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
lh.ifce = &mockEncWriter{}
|
|
||||||
|
|
||||||
//insert the host
|
|
||||||
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
|
|
||||||
am.vpnAddrs = []netip.Addr{testHost}
|
|
||||||
am.addrs = []netip.AddrPort{myUdpAddr2}
|
|
||||||
lh.addrMap[testHost] = am
|
|
||||||
am.Rebuild([]netip.Prefix{}) //???
|
|
||||||
|
|
||||||
//test that we actually have the entry:
|
|
||||||
out := lh.Query(testHost)
|
|
||||||
assert.NotNil(t, out)
|
|
||||||
assert.Equal(t, out.vpnAddrs[0], testHost)
|
|
||||||
out.Rebuild([]netip.Prefix{}) //why tho
|
|
||||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
|
||||||
|
|
||||||
//now do the delete
|
|
||||||
lh.DeleteVpnAddrs([]netip.Addr{testHost})
|
|
||||||
//verify
|
|
||||||
out = lh.Query(testHost)
|
|
||||||
assert.Nil(t, out)
|
|
||||||
}
|
|
||||||
|
|||||||
36
main.go
36
main.go
@@ -164,7 +164,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
|
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
|
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
}
|
}
|
||||||
@@ -185,7 +185,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
|
|
||||||
hostMap := NewHostMapFromConfig(l, c)
|
hostMap := NewHostMapFromConfig(l, c)
|
||||||
punchy := NewPunchyFromConfig(l, c)
|
punchy := NewPunchyFromConfig(l, c)
|
||||||
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
|
|
||||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
||||||
@@ -221,15 +220,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
batchCfg := BatchConfig{
|
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
|
||||||
InboundBatchSize: c.GetInt("batch.inbound_size", inboundBatchSizeDefault),
|
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
|
||||||
OutboundBatchSize: c.GetInt("batch.outbound_size", outboundBatchSizeDefault),
|
|
||||||
FlushInterval: c.GetDuration("batch.flush_interval", batchFlushIntervalDefault),
|
|
||||||
MaxOutstandingPerChan: c.GetInt("batch.max_outstanding", maxOutstandingBatchesDefault),
|
|
||||||
MaxPendingPackets: c.GetInt("batch.max_pending_packets", 0),
|
|
||||||
MaxPendingBytes: c.GetInt("batch.max_pending_bytes", 0),
|
|
||||||
MaxSendBuffersPerChan: c.GetInt("batch.max_send_buffers_per_routine", 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
HostMap: hostMap,
|
HostMap: hostMap,
|
||||||
@@ -239,8 +231,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
Firewall: fw,
|
Firewall: fw,
|
||||||
ServeDns: serveDns,
|
ServeDns: serveDns,
|
||||||
HandshakeManager: handshakeManager,
|
HandshakeManager: handshakeManager,
|
||||||
connectionManager: connManager,
|
|
||||||
lightHouse: lightHouse,
|
lightHouse: lightHouse,
|
||||||
|
checkInterval: time.Second * time.Duration(checkInterval),
|
||||||
|
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
|
||||||
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
|
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
|
||||||
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
|
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
|
||||||
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
||||||
@@ -251,8 +244,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
version: buildVersion,
|
version: buildVersion,
|
||||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||||
punchy: punchy,
|
punchy: punchy,
|
||||||
|
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
BatchConfig: batchCfg,
|
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,14 +288,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
f: ifce,
|
ifce,
|
||||||
l: l,
|
l,
|
||||||
ctx: ctx,
|
ctx,
|
||||||
cancel: cancel,
|
cancel,
|
||||||
sshStart: sshStart,
|
sshStart,
|
||||||
statsStart: statsStart,
|
statsStart,
|
||||||
dnsStart: dnsStart,
|
dnsStart,
|
||||||
lighthouseStart: lightHouse.StartUpdateWorker,
|
lightHouse.StartUpdateWorker,
|
||||||
connectionManagerStart: connManager.Start,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
36
outside.go
36
outside.go
@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//f.l.Error("in packet ", h)
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
@@ -81,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
// Pull the Roaming parts up here, and return in all call paths.
|
// Pull the Roaming parts up here, and return in all call paths.
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo.localIndexId)
|
||||||
f.connectionManager.RelayUsed(h.RemoteIndex)
|
f.connectionManager.RelayUsed(h.RemoteIndex)
|
||||||
|
|
||||||
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
||||||
@@ -213,7 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
|
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo.localIndexId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
||||||
@@ -245,7 +245,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
||||||
Info("Host roamed to new udp ip/port.")
|
Info("Host roamed to new udp ip/port.")
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
@@ -255,18 +254,16 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleEncrypted returns true if a packet should be processed, false otherwise
|
|
||||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
||||||
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
|
// If connectionstate exists and the replay protector allows, process packet
|
||||||
if ci == nil {
|
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
||||||
|
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
f.maybeSendRecvError(addr, h.RemoteIndex)
|
f.maybeSendRecvError(addr, h.RemoteIndex)
|
||||||
}
|
return false
|
||||||
|
} else {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// If the window check fails, refuse to process the packet, but don't send a recv error
|
|
||||||
if !ci.window.Check(f.l, h.MessageCounter) {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@@ -315,11 +312,12 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
||||||
next := 0
|
next := 0
|
||||||
for {
|
for {
|
||||||
if protoAt >= dataLen {
|
if dataLen < offset {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
proto := layers.IPProtocol(data[protoAt])
|
|
||||||
|
|
||||||
|
proto := layers.IPProtocol(data[protoAt])
|
||||||
|
//fmt.Println(proto, protoAt)
|
||||||
switch proto {
|
switch proto {
|
||||||
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||||
fp.Protocol = uint8(proto)
|
fp.Protocol = uint8(proto)
|
||||||
@@ -367,7 +365,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
|
|
||||||
case layers.IPProtocolAH:
|
case layers.IPProtocolAH:
|
||||||
// Auth headers, used by IPSec, have a different meaning for header length
|
// Auth headers, used by IPSec, have a different meaning for header length
|
||||||
if dataLen <= offset+1 {
|
if dataLen < offset+1 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -375,7 +373,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
|
|
||||||
default:
|
default:
|
||||||
// Normal ipv6 header length processing
|
// Normal ipv6 header length processing
|
||||||
if dataLen <= offset+1 {
|
if dataLen < offset+1 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -471,7 +469,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -501,7 +499,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo.localIndexId)
|
||||||
_, err = f.readers[q].Write(out)
|
_, err = f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
@@ -540,6 +538,10 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !hostinfo.RecvErrorExceeded() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -117,45 +117,6 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
err = newPacket(buffer.Bytes(), true, p)
|
err = newPacket(buffer.Bytes(), true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A v6 packet with a hop-by-hop extension
|
|
||||||
// ICMPv6 Payload (Echo Request)
|
|
||||||
icmpLayer := layers.ICMPv6{
|
|
||||||
TypeCode: layers.ICMPv6TypeEchoRequest,
|
|
||||||
}
|
|
||||||
// Hop-by-Hop Extension Header
|
|
||||||
hopOption := layers.IPv6HopByHopOption{}
|
|
||||||
hopOption.OptionData = []byte{0, 0, 0, 0}
|
|
||||||
hopByHop := layers.IPv6HopByHop{}
|
|
||||||
hopByHop.Options = append(hopByHop.Options, &hopOption)
|
|
||||||
|
|
||||||
ip = layers.IPv6{
|
|
||||||
Version: 6,
|
|
||||||
HopLimit: 128,
|
|
||||||
NextHeader: layers.IPProtocolIPv6Destination,
|
|
||||||
SrcIP: net.IPv6linklocalallrouters,
|
|
||||||
DstIP: net.IPv6linklocalallnodes,
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer.Clear()
|
|
||||||
err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: false,
|
|
||||||
FixLengths: true,
|
|
||||||
}, &ip, &hopByHop, &icmpLayer)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
// Ensure buffer length checks during parsing with the next 2 tests.
|
|
||||||
|
|
||||||
// A full IPv6 header and 1 byte in the first extension, but missing
|
|
||||||
// the length byte.
|
|
||||||
err = newPacket(buffer.Bytes()[:41], true, p)
|
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
|
||||||
|
|
||||||
// A full IPv6 header plus 1 full extension, but only 1 byte of the
|
|
||||||
// next layer, missing length byte
|
|
||||||
err = newPacket(buffer.Bytes()[:49], true, p)
|
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
|
||||||
|
|
||||||
// A good ICMP packet
|
// A good ICMP packet
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
Version: 6,
|
Version: 6,
|
||||||
@@ -327,10 +288,6 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
assert.Equal(t, uint16(22), p.LocalPort)
|
assert.Equal(t, uint16(22), p.LocalPort)
|
||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
// Ensure buffer bounds checking during processing
|
|
||||||
err = newPacket(b[:41], true, p)
|
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
|
||||||
|
|
||||||
// Invalid AH header
|
// Invalid AH header
|
||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -71,13 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
|||||||
|
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
|
||||||
pLen := 128
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
pLen = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -553,3 +554,13 @@ func (t *tun) Name() string {
|
|||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||||
|
pLen := 128
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
pLen = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
@@ -20,18 +22,12 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
||||||
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
||||||
FIODGNAME = 0x80106678
|
FIODGNAME = 0x80106678
|
||||||
TUNSIFMODE = 0x8004745e
|
|
||||||
TUNSIFHEAD = 0x80047460
|
|
||||||
OSIOCAIFADDR_IN6 = 0x8088691b
|
|
||||||
IN6_IFF_NODAD = 0x0020
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type fiodgnameArg struct {
|
type fiodgnameArg struct {
|
||||||
@@ -41,159 +37,43 @@ type fiodgnameArg struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ifreqRename struct {
|
type ifreqRename struct {
|
||||||
Name [unix.IFNAMSIZ]byte
|
Name [16]byte
|
||||||
Data uintptr
|
Data uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifreqDestroy struct {
|
type ifreqDestroy struct {
|
||||||
Name [unix.IFNAMSIZ]byte
|
Name [16]byte
|
||||||
pad [16]byte
|
pad [16]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifReq struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Flags uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqMTU struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
MTU int32
|
|
||||||
}
|
|
||||||
|
|
||||||
type addrLifetime struct {
|
|
||||||
Expire uint64
|
|
||||||
Preferred uint64
|
|
||||||
Vltime uint32
|
|
||||||
Pltime uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias4 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
DstAddr unix.RawSockaddrInet4
|
|
||||||
MaskAddr unix.RawSockaddrInet4
|
|
||||||
VHid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias6 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet6
|
|
||||||
DstAddr unix.RawSockaddrInet6
|
|
||||||
PrefixMask unix.RawSockaddrInet6
|
|
||||||
Flags uint32
|
|
||||||
Lifetime addrLifetime
|
|
||||||
VHid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
linkAddr *netroute.LinkAddr
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
devFd int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
io.ReadWriteCloser
|
||||||
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
|
||||||
if t.devFd < 0 {
|
|
||||||
return -1, syscall.EINVAL
|
|
||||||
}
|
|
||||||
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
head := make([]byte, 4)
|
|
||||||
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&to[0], uint64(len(to))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if errno != 0 {
|
|
||||||
err = syscall.Errno(errno)
|
|
||||||
} else {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
// fix bytes read number to exclude header
|
|
||||||
bytesRead := int(n)
|
|
||||||
if bytesRead < 0 {
|
|
||||||
return bytesRead, err
|
|
||||||
} else if bytesRead < 4 {
|
|
||||||
return 0, err
|
|
||||||
} else {
|
|
||||||
return bytesRead - 4, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
|
||||||
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
|
||||||
if t.devFd < 0 {
|
|
||||||
return -1, syscall.EINVAL
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(from) <= 1 {
|
|
||||||
return 0, syscall.EIO
|
|
||||||
}
|
|
||||||
ipVer := from[0] >> 4
|
|
||||||
var head []byte
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
if ipVer == 4 {
|
|
||||||
head = []byte{0, 0, 0, syscall.AF_INET}
|
|
||||||
} else if ipVer == 6 {
|
|
||||||
head = []byte{0, 0, 0, syscall.AF_INET6}
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
|
||||||
}
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&from[0], uint64(len(from))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if errno != 0 {
|
|
||||||
err = syscall.Errno(errno)
|
|
||||||
} else {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(n) - 4, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *tun) Close() error {
|
||||||
if t.devFd >= 0 {
|
if t.ReadWriteCloser != nil {
|
||||||
err := syscall.Close(t.devFd)
|
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
t.l.WithError(err).Error("Error closing device")
|
|
||||||
}
|
}
|
||||||
t.devFd = -1
|
|
||||||
|
|
||||||
c := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
|
||||||
defer close(c)
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
if err == nil {
|
|
||||||
defer syscall.Close(s)
|
|
||||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
|
||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).Error("Error destroying tunnel")
|
return err
|
||||||
}
|
}
|
||||||
}()
|
defer syscall.Close(s)
|
||||||
|
|
||||||
// wait up to 1 second so we start blocking at the ioctl
|
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||||
select {
|
|
||||||
case <-c:
|
// Destroy the interface
|
||||||
case <-time.After(1 * time.Second):
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||||
}
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -205,37 +85,32 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
|
|||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open existing tun device
|
// Try to open existing tun device
|
||||||
var fd int
|
var file *os.File
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName != "" {
|
if deviceName != "" {
|
||||||
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
}
|
}
|
||||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||||
// If the device doesn't already exist, request a new one and rename it
|
// If the device doesn't already exist, request a new one and rename it
|
||||||
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the name of the interface
|
rawConn, err := file.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SyscallConn: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var name [16]byte
|
var name [16]byte
|
||||||
|
var ctrlErr error
|
||||||
|
rawConn.Control(func(fd uintptr) {
|
||||||
|
// Read the name of the interface
|
||||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||||
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||||
|
})
|
||||||
if ctrlErr == nil {
|
|
||||||
// set broadcast mode and multicast
|
|
||||||
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
|
||||||
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctrlErr == nil {
|
|
||||||
// turn on link-layer mode, to support ipv6
|
|
||||||
ifhead := uint32(1)
|
|
||||||
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctrlErr != nil {
|
if ctrlErr != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -247,7 +122,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
|
|
||||||
// If the name doesn't match the desired interface name, rename it now
|
// If the name doesn't match the desired interface name, rename it now
|
||||||
if ifName != deviceName {
|
if ifName != deviceName {
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
s, err := syscall.Socket(
|
||||||
|
syscall.AF_INET,
|
||||||
|
syscall.SOCK_DGRAM,
|
||||||
|
syscall.IPPROTO_IP,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -270,11 +149,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
|
ReadWriteCloser: file,
|
||||||
Device: deviceName,
|
Device: deviceName,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
devFd: fd,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -293,111 +172,38 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
if cidr.Addr().Is4() {
|
var err error
|
||||||
ifr := ifreqAlias4{
|
// TODO use syscalls instead of exec.Command
|
||||||
Name: t.deviceBytes(),
|
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||||
Addr: unix.RawSockaddrInet4{
|
t.l.Debug("command: ", cmd.String())
|
||||||
Len: unix.SizeofSockaddrInet4,
|
if err = cmd.Run(); err != nil {
|
||||||
Family: unix.AF_INET,
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
},
|
|
||||||
DstAddr: unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: getBroadcast(cidr).As4(),
|
|
||||||
},
|
|
||||||
MaskAddr: unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: prefixToMask(cidr).As4(),
|
|
||||||
},
|
|
||||||
VHid: 0,
|
|
||||||
}
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
|
||||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cidr.Addr().Is6() {
|
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
|
||||||
ifr := ifreqAlias6{
|
t.l.Debug("command: ", cmd.String())
|
||||||
Name: t.deviceBytes(),
|
if err = cmd.Run(); err != nil {
|
||||||
Addr: unix.RawSockaddrInet6{
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: cidr.Addr().As16(),
|
|
||||||
},
|
|
||||||
PrefixMask: unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: prefixToMask(cidr).As16(),
|
|
||||||
},
|
|
||||||
Lifetime: addrLifetime{
|
|
||||||
Expire: 0,
|
|
||||||
Preferred: 0,
|
|
||||||
Vltime: 0xffffffff,
|
|
||||||
Pltime: 0xffffffff,
|
|
||||||
},
|
|
||||||
Flags: IN6_IFF_NODAD,
|
|
||||||
}
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("unknown address type %v", cidr)
|
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsafe path routes
|
||||||
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) Activate() error {
|
||||||
// Setup our default MTU
|
|
||||||
err := t.setMTU()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
linkAddr, err := getLinkAddr(t.Device)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if linkAddr == nil {
|
|
||||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
|
||||||
}
|
|
||||||
t.linkAddr = linkAddr
|
|
||||||
|
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
err := t.addIp(t.vpnNetworks[i])
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return t.addRoutes(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) setMTU() error {
|
|
||||||
// Set the MTU on the device
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
|
||||||
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
@@ -462,16 +268,15 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := addRoute(r.Cidr, t.linkAddr)
|
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||||
if err != nil {
|
t.l.Debug("command: ", cmd.String())
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
if err := cmd.Run(); err != nil {
|
||||||
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Added route")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -484,8 +289,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := delRoute(r.Cidr, t.linkAddr)
|
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||||
if err != nil {
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
@@ -500,144 +306,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func flipBytes(b []byte) []byte {
|
|
||||||
for i := 0; i < len(b); i++ {
|
|
||||||
b[i] ^= 0xFF
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
func orBytes(a []byte, b []byte) []byte {
|
|
||||||
ret := make([]byte, len(a))
|
|
||||||
for i := 0; i < len(a); i++ {
|
|
||||||
ret[i] = a[i] | b[i]
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
|
||||||
broadcast, _ := netip.AddrFromSlice(
|
|
||||||
orBytes(
|
|
||||||
cidr.Addr().AsSlice(),
|
|
||||||
flipBytes(prefixToMask(cidr).AsSlice()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return broadcast
|
|
||||||
}
|
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := &netroute.RouteMessage{
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Type: unix.RTM_ADD,
|
|
||||||
Flags: unix.RTF_UP,
|
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
// Try to do a change
|
|
||||||
route.Type = unix.RTM_CHANGE
|
|
||||||
data, err = route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
fmt.Println("DOING CHANGE")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := netroute.RouteMessage{
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Type: unix.RTM_DELETE,
|
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getLinkAddr Gets the link address for the interface of the given name
|
|
||||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
|
||||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range msgs {
|
|
||||||
switch m := m.(type) {
|
|
||||||
case *netroute.InterfaceMessage:
|
|
||||||
if m.Name == name {
|
|
||||||
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
|
||||||
if ok {
|
|
||||||
return sa, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -293,6 +293,7 @@ func (t *tun) addIPs(link netlink.Link) error {
|
|||||||
|
|
||||||
//add all new addresses
|
//add all new addresses
|
||||||
for i := range newAddrs {
|
for i := range newAddrs {
|
||||||
|
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
|
||||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -360,11 +361,6 @@ func (t *tun) Activate() error {
|
|||||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||||
}
|
}
|
||||||
|
|
||||||
const modeNone = 1
|
|
||||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
|
||||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = t.addIPs(link); err != nil {
|
if err = t.addIPs(link); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -642,11 +638,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Dst == nil {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
package packet
|
|
||||||
|
|
||||||
import "net/netip"
|
|
||||||
|
|
||||||
type Packet struct {
|
|
||||||
Payload []byte
|
|
||||||
Addr netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() *Packet {
|
|
||||||
return &Packet{Payload: make([]byte, 9001)}
|
|
||||||
}
|
|
||||||
@@ -180,7 +180,6 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
|
|||||||
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
|
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
|
||||||
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
|
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
|
||||||
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
|
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
|
||||||
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up the parameters which include the peer's public key
|
// Set up the parameters which include the peer's public key
|
||||||
|
|||||||
5
pki.go
5
pki.go
@@ -173,6 +173,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
|
|
||||||
p.cs.Store(newState)
|
p.cs.Store(newState)
|
||||||
|
|
||||||
|
//TODO: CERT-V2 newState needs a stringer that does json
|
||||||
if initial {
|
if initial {
|
||||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
||||||
} else {
|
} else {
|
||||||
@@ -358,9 +359,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if v1.Networks()[0] != v2.Networks()[0] {
|
//TODO: CERT-V2 make sure v2 has v1s address
|
||||||
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
cs.initiatingVersion = dv
|
cs.initiatingVersion = dv
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ type RemoteList struct {
|
|||||||
// The full list of vpn addresses assigned to this host
|
// The full list of vpn addresses assigned to this host
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
|
|
||||||
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
|
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
||||||
addrs []netip.AddrPort
|
addrs []netip.AddrPort
|
||||||
|
|
||||||
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
|
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
|
||||||
@@ -202,9 +202,7 @@ type RemoteList struct {
|
|||||||
cache map[netip.Addr]*cache
|
cache map[netip.Addr]*cache
|
||||||
|
|
||||||
hr *hostnamesResults
|
hr *hostnamesResults
|
||||||
|
shouldAdd func(netip.Addr) bool
|
||||||
// shouldAdd is a nillable function that decides if x should be added to addrs.
|
|
||||||
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
|
|
||||||
|
|
||||||
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
||||||
// They should not be tried again during a handshake
|
// They should not be tried again during a handshake
|
||||||
@@ -215,7 +213,7 @@ type RemoteList struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewRemoteList creates a new empty RemoteList
|
// NewRemoteList creates a new empty RemoteList
|
||||||
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
|
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
|
||||||
r := &RemoteList{
|
r := &RemoteList{
|
||||||
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
|
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
|
||||||
addrs: make([]netip.AddrPort, 0),
|
addrs: make([]netip.AddrPort, 0),
|
||||||
@@ -370,15 +368,6 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
|
|
||||||
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
|
|
||||||
r.Lock()
|
|
||||||
r.badRemotes = nil
|
|
||||||
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
|
|
||||||
copy(r.vpnAddrs, vpnAddrs)
|
|
||||||
r.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetBlockedRemotes locks and clears the blocked remotes list
|
// ResetBlockedRemotes locks and clears the blocked remotes list
|
||||||
func (r *RemoteList) ResetBlockedRemotes() {
|
func (r *RemoteList) ResetBlockedRemotes() {
|
||||||
r.Lock()
|
r.Lock()
|
||||||
@@ -588,7 +577,7 @@ func (r *RemoteList) unlockedCollect() {
|
|||||||
|
|
||||||
dnsAddrs := r.hr.GetAddrs()
|
dnsAddrs := r.hr.GetAddrs()
|
||||||
for _, addr := range dnsAddrs {
|
for _, addr := range dnsAddrs {
|
||||||
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
|
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
||||||
if !r.unlockedIsBad(addr) {
|
if !r.unlockedIsBad(addr) {
|
||||||
addrs = append(addrs, addr)
|
addrs = append(addrs, addr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,10 +44,7 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(control *nebula.Control) (*Service, error) {
|
func New(control *nebula.Control) (*Service, error) {
|
||||||
wait, err := control.Start()
|
control.Start()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := control.Context()
|
ctx := control.Context()
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
@@ -144,12 +141,6 @@ func New(control *nebula.Control) (*Service, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add the nebula wait function to the group
|
|
||||||
eg.Go(func() error {
|
|
||||||
wait()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
15
udp/conn.go
15
udp/conn.go
@@ -16,18 +16,12 @@ type EncReader func(
|
|||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader) error
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
WriteBatch(pkts []BatchPacket) (int, error)
|
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type BatchPacket struct {
|
|
||||||
Payload []byte
|
|
||||||
Addr netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
type NoopConn struct{}
|
type NoopConn struct{}
|
||||||
|
|
||||||
func (NoopConn) Rebind() error {
|
func (NoopConn) Rebind() error {
|
||||||
@@ -36,15 +30,12 @@ func (NoopConn) Rebind() error {
|
|||||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
return netip.AddrPort{}, nil
|
return netip.AddrPort{}, nil
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOut(_ EncReader) error {
|
func (NoopConn) ListenOut(_ EncReader) {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (NoopConn) WriteBatch(_ []BatchPacket) (int, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
package udp
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")
|
|
||||||
@@ -3,62 +3,20 @@
|
|||||||
|
|
||||||
package udp
|
package udp
|
||||||
|
|
||||||
|
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StdConn struct {
|
|
||||||
*net.UDPConn
|
|
||||||
isV4 bool
|
|
||||||
sysFd uintptr
|
|
||||||
l *logrus.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Conn = &StdConn{}
|
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
lc := NewListenConfig(multi)
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if uc, ok := pc.(*net.UDPConn); ok {
|
|
||||||
c := &StdConn{UDPConn: uc, l: l}
|
|
||||||
|
|
||||||
rc, err := uc.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to open udp socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rc.Control(func(fd uintptr) {
|
|
||||||
c.sysFd = fd
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get udp fd: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
la, err := c.LocalAddr()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.isV4 = la.Addr().Is4()
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListenConfig(multi bool) net.ListenConfig {
|
func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
@@ -85,130 +43,16 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:linkname sendto golang.org/x/sys/unix.sendto
|
func (u *GenericConn) Rebind() error {
|
||||||
//go:noescape
|
rc, err := u.UDPConn.SyscallConn()
|
||||||
func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
|
|
||||||
|
|
||||||
func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
|
||||||
var sa unsafe.Pointer
|
|
||||||
var addrLen int32
|
|
||||||
|
|
||||||
if u.isV4 {
|
|
||||||
if ap.Addr().Is6() {
|
|
||||||
return ErrInvalidIPv6RemoteForSocket
|
|
||||||
}
|
|
||||||
|
|
||||||
var rsa unix.RawSockaddrInet6
|
|
||||||
rsa.Family = unix.AF_INET6
|
|
||||||
rsa.Addr = ap.Addr().As16()
|
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
|
|
||||||
sa = unsafe.Pointer(&rsa)
|
|
||||||
addrLen = syscall.SizeofSockaddrInet4
|
|
||||||
} else {
|
|
||||||
var rsa unix.RawSockaddrInet6
|
|
||||||
rsa.Family = unix.AF_INET6
|
|
||||||
rsa.Addr = ap.Addr().As16()
|
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
|
|
||||||
sa = unsafe.Pointer(&rsa)
|
|
||||||
addrLen = syscall.SizeofSockaddrInet6
|
|
||||||
}
|
|
||||||
|
|
||||||
// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
|
|
||||||
// See https://github.com/golang/go/issues/73919
|
|
||||||
for {
|
|
||||||
//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
|
|
||||||
err := sendto(int(u.sysFd), b, 0, sa, addrLen)
|
|
||||||
if err == nil {
|
|
||||||
// Written, get out before the error handling
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if errors.Is(err, syscall.EINTR) {
|
|
||||||
// Write was interrupted, retry
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if errors.Is(err, syscall.EAGAIN) {
|
|
||||||
return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
|
|
||||||
}
|
|
||||||
|
|
||||||
if errors.Is(err, syscall.EBADF) {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
return &net.OpError{Op: "sendto", Err: err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
|
||||||
sent := 0
|
|
||||||
for _, pkt := range pkts {
|
|
||||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
|
||||||
return sent, err
|
|
||||||
}
|
|
||||||
sent++
|
|
||||||
}
|
|
||||||
return sent, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|
||||||
a := u.UDPConn.LocalAddr()
|
|
||||||
|
|
||||||
switch v := a.(type) {
|
|
||||||
case *net.UDPAddr:
|
|
||||||
addr, ok := netip.AddrFromSlice(v.IP)
|
|
||||||
if !ok {
|
|
||||||
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
|
|
||||||
}
|
|
||||||
return netip.AddrPortFrom(addr, uint16(v.Port)), nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
|
||||||
// TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
|
||||||
// No UDP stats for non-linux
|
|
||||||
return func() {}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) error {
|
|
||||||
buffer := make([]byte, MTU)
|
|
||||||
|
|
||||||
for {
|
|
||||||
// Just read one packet at a time
|
|
||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
return rc.Control(func(fd uintptr) {
|
||||||
continue
|
err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
|
||||||
}
|
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
|
||||||
var err error
|
|
||||||
if u.isV4 {
|
|
||||||
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
|
|
||||||
} else {
|
|
||||||
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Error("Failed to rebind udp socket")
|
u.l.WithError(err).Error("Failed to rebind udp socket")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
//go:build (!linux || android) && !e2e_testing && !darwin
|
//go:build (!linux || android) && !e2e_testing
|
||||||
// +build !linux android
|
// +build !linux android
|
||||||
// +build !e2e_testing
|
// +build !e2e_testing
|
||||||
// +build !darwin
|
|
||||||
|
|
||||||
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
||||||
// means it can be used on platforms like Darwin and Windows.
|
// means it can be used on platforms like Darwin and Windows.
|
||||||
@@ -42,17 +41,6 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
|
||||||
sent := 0
|
|
||||||
for _, pkt := range pkts {
|
|
||||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
|
||||||
return sent, err
|
|
||||||
}
|
|
||||||
sent++
|
|
||||||
}
|
|
||||||
return sent, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
a := u.UDPConn.LocalAddr()
|
a := u.UDPConn.LocalAddr()
|
||||||
|
|
||||||
@@ -82,14 +70,15 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader) error {
|
func (u *GenericConn) ListenOut(r EncReader) {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
|
|||||||
609
udp/udp_linux.go
609
udp/udp_linux.go
@@ -5,13 +5,10 @@ package udp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
@@ -20,40 +17,19 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultGSOMaxSegments = 128
|
|
||||||
defaultGSOFlushTimeout = 80 * time.Microsecond
|
|
||||||
defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
|
|
||||||
maxGSOBatchBytes = 0xFFFF
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errGSOFallback = errors.New("udp gso fallback")
|
|
||||||
errGSODisabled = errors.New("udp gso disabled")
|
|
||||||
)
|
|
||||||
|
|
||||||
type StdConn struct {
|
type StdConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
batch int
|
batch int
|
||||||
|
}
|
||||||
|
|
||||||
enableGRO bool
|
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||||
enableGSO bool
|
ip4 := ip.To4()
|
||||||
|
if ip4 != nil {
|
||||||
gsoMu sync.Mutex
|
return ip4, true
|
||||||
gsoBuf []byte
|
}
|
||||||
gsoAddr netip.AddrPort
|
return ip, false
|
||||||
gsoSegSize int
|
|
||||||
gsoSegments int
|
|
||||||
gsoMaxSegments int
|
|
||||||
gsoMaxBytes int
|
|
||||||
gsoFlushTimeout time.Duration
|
|
||||||
gsoTimer *time.Timer
|
|
||||||
|
|
||||||
groBufSize int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
@@ -79,11 +55,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set a read timeout
|
|
||||||
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var sa unix.Sockaddr
|
var sa unix.Sockaddr
|
||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
sa4 := &unix.SockaddrInet4{Port: port}
|
sa4 := &unix.SockaddrInet4{Port: port}
|
||||||
@@ -98,16 +69,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &StdConn{
|
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||||
sysFd: fd,
|
|
||||||
isV4: ip.Is4(),
|
|
||||||
l: l,
|
|
||||||
batch: batch,
|
|
||||||
gsoMaxSegments: defaultGSOMaxSegments,
|
|
||||||
gsoMaxBytes: MTU * defaultGSOMaxSegments,
|
|
||||||
gsoFlushTimeout: defaultGSOFlushTimeout,
|
|
||||||
groBufSize: MTU,
|
|
||||||
}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
@@ -156,46 +118,20 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) error {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
var (
|
var ip netip.Addr
|
||||||
ip netip.Addr
|
|
||||||
controls [][]byte
|
|
||||||
)
|
|
||||||
|
|
||||||
bufSize := u.readBufferSize()
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
|
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
desired := u.readBufferSize()
|
|
||||||
if len(buffers) == 0 || cap(buffers[0]) < desired {
|
|
||||||
msgs, buffers, names = u.PrepareRawMessages(u.batch, desired)
|
|
||||||
controls = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.enableGRO {
|
|
||||||
if controls == nil {
|
|
||||||
controls = make([][]byte, len(msgs))
|
|
||||||
for i := range controls {
|
|
||||||
controls[i] = make([]byte, unix.CmsgSpace(4))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := range msgs {
|
|
||||||
setRawMessageControl(&msgs[i], controls[i])
|
|
||||||
}
|
|
||||||
} else if controls != nil {
|
|
||||||
for i := range msgs {
|
|
||||||
setRawMessageControl(&msgs[i], nil)
|
|
||||||
}
|
|
||||||
controls = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
@@ -205,80 +141,9 @@ func (u *StdConn) ListenOut(r EncReader) error {
|
|||||||
} else {
|
} else {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
}
|
}
|
||||||
addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||||
payload := buffers[i][:msgs[i].Len]
|
|
||||||
|
|
||||||
if u.enableGRO && u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
ctrlLen := getRawMessageControlLen(&msgs[i])
|
|
||||||
msgFlags := getRawMessageFlags(&msgs[i])
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "recv",
|
|
||||||
"payload_len": len(payload),
|
|
||||||
"ctrl_len": ctrlLen,
|
|
||||||
"msg_flags": msgFlags,
|
|
||||||
}).Debug("gro batch data")
|
|
||||||
if controls != nil && ctrlLen > 0 {
|
|
||||||
maxDump := ctrlLen
|
|
||||||
if maxDump > 16 {
|
|
||||||
maxDump = 16
|
|
||||||
}
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "control-bytes",
|
|
||||||
"control_hex": fmt.Sprintf("%x", controls[i][:maxDump]),
|
|
||||||
"datalen": ctrlLen,
|
|
||||||
}).Debug("gro control dump")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sawControl := false
|
|
||||||
if controls != nil {
|
|
||||||
if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
|
|
||||||
if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segSize > 0 {
|
|
||||||
sawControl = true
|
|
||||||
if u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "control",
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
"payloadLen": len(payload),
|
|
||||||
}).Debug("gro control parsed")
|
|
||||||
}
|
|
||||||
segSize = normalizeGROSegSize(segSize, segCount, len(payload))
|
|
||||||
if segSize > 0 && segSize < len(payload) {
|
|
||||||
if u.emitGROSegments(r, addr, payload, segSize) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.enableGRO && len(payload) > MTU {
|
|
||||||
if !sawControl && u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "fallback",
|
|
||||||
"payload_len": len(payload),
|
|
||||||
}).Debug("gro control missing; splitting payload by MTU")
|
|
||||||
}
|
|
||||||
if u.emitGROSegments(r, addr, payload, MTU) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r(addr, payload)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) readBufferSize() int {
|
|
||||||
if u.enableGRO && u.groBufSize > MTU {
|
|
||||||
return u.groBufSize
|
|
||||||
}
|
|
||||||
return MTU
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||||
@@ -294,9 +159,6 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
if err == unix.EAGAIN || err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,9 +180,6 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
if err == unix.EAGAIN || err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -329,132 +188,12 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||||
if u.enableGSO && ip.IsValid() {
|
|
||||||
if err := u.queueGSOPacket(b, ip); err == nil {
|
|
||||||
return nil
|
|
||||||
} else if !errors.Is(err, errGSOFallback) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
return u.writeTo4(b, ip)
|
return u.writeTo4(b, ip)
|
||||||
}
|
}
|
||||||
return u.writeTo6(b, ip)
|
return u.writeTo6(b, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
|
||||||
if len(pkts) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs := make([]rawMessage, 0, len(pkts))
|
|
||||||
iovs := make([]iovec, 0, len(pkts))
|
|
||||||
names := make([][unix.SizeofSockaddrInet6]byte, 0, len(pkts))
|
|
||||||
|
|
||||||
sent := 0
|
|
||||||
|
|
||||||
for _, pkt := range pkts {
|
|
||||||
if len(pkt.Payload) == 0 {
|
|
||||||
sent++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.enableGSO && pkt.Addr.IsValid() {
|
|
||||||
if err := u.queueGSOPacket(pkt.Payload, pkt.Addr); err == nil {
|
|
||||||
sent++
|
|
||||||
continue
|
|
||||||
} else if !errors.Is(err, errGSOFallback) {
|
|
||||||
return sent, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pkt.Addr.IsValid() {
|
|
||||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
|
||||||
return sent, err
|
|
||||||
}
|
|
||||||
sent++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs = append(msgs, rawMessage{})
|
|
||||||
iovs = append(iovs, iovec{})
|
|
||||||
names = append(names, [unix.SizeofSockaddrInet6]byte{})
|
|
||||||
|
|
||||||
idx := len(msgs) - 1
|
|
||||||
msg := &msgs[idx]
|
|
||||||
iov := &iovs[idx]
|
|
||||||
name := &names[idx]
|
|
||||||
|
|
||||||
setIovecSlice(iov, pkt.Payload)
|
|
||||||
msg.Hdr.Iov = iov
|
|
||||||
msg.Hdr.Iovlen = 1
|
|
||||||
setRawMessageControl(msg, nil)
|
|
||||||
msg.Hdr.Flags = 0
|
|
||||||
|
|
||||||
nameLen, err := u.encodeSockaddr(name[:], pkt.Addr)
|
|
||||||
if err != nil {
|
|
||||||
return sent, err
|
|
||||||
}
|
|
||||||
msg.Hdr.Name = &name[0]
|
|
||||||
msg.Hdr.Namelen = nameLen
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(msgs) == 0 {
|
|
||||||
return sent, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := 0
|
|
||||||
for offset < len(msgs) {
|
|
||||||
n, _, errno := unix.Syscall6(
|
|
||||||
unix.SYS_SENDMMSG,
|
|
||||||
uintptr(u.sysFd),
|
|
||||||
uintptr(unsafe.Pointer(&msgs[offset])),
|
|
||||||
uintptr(len(msgs)-offset),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
if errno == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
|
|
||||||
}
|
|
||||||
|
|
||||||
if n == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
offset += int(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
return sent + len(msgs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
|
||||||
if u.isV4 {
|
|
||||||
if !addr.Addr().Is4() {
|
|
||||||
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
|
||||||
}
|
|
||||||
var sa unix.RawSockaddrInet4
|
|
||||||
sa.Family = unix.AF_INET
|
|
||||||
sa.Addr = addr.Addr().As4()
|
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
|
||||||
size := unix.SizeofSockaddrInet4
|
|
||||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
|
|
||||||
return uint32(size), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var sa unix.RawSockaddrInet6
|
|
||||||
sa.Family = unix.AF_INET6
|
|
||||||
sa.Addr = addr.Addr().As16()
|
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
|
||||||
size := unix.SizeofSockaddrInet6
|
|
||||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
|
|
||||||
return uint32(size), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||||
var rsa unix.RawSockaddrInet6
|
var rsa unix.RawSockaddrInet6
|
||||||
rsa.Family = unix.AF_INET6
|
rsa.Family = unix.AF_INET6
|
||||||
@@ -555,94 +294,6 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
u.configureGRO(c)
|
|
||||||
u.configureGSO(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) configureGRO(c *config.C) {
|
|
||||||
if c == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
enable := c.GetBool("listen.enable_gro", true)
|
|
||||||
if enable == u.enableGRO {
|
|
||||||
if enable {
|
|
||||||
if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 {
|
|
||||||
u.setGROBufferSize(size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if enable {
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
u.enableGRO = true
|
|
||||||
u.setGROBufferSize(c.GetInt("listen.gro_read_buffer", defaultGROReadBufferSize))
|
|
||||||
u.l.WithField("buffer_size", u.groBufSize).Info("UDP GRO enabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
|
|
||||||
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
|
||||||
}
|
|
||||||
u.enableGRO = false
|
|
||||||
u.groBufSize = MTU
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) configureGSO(c *config.C) {
|
|
||||||
enable := c.GetBool("listen.enable_gso", true)
|
|
||||||
if !enable {
|
|
||||||
u.disableGSO()
|
|
||||||
} else {
|
|
||||||
u.enableGSO = true
|
|
||||||
}
|
|
||||||
|
|
||||||
segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
|
||||||
if segments < 1 {
|
|
||||||
segments = 1
|
|
||||||
}
|
|
||||||
u.gsoMaxSegments = segments
|
|
||||||
|
|
||||||
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
|
|
||||||
if maxBytes <= 0 {
|
|
||||||
maxBytes = MTU * segments
|
|
||||||
}
|
|
||||||
if maxBytes > maxGSOBatchBytes {
|
|
||||||
u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
|
|
||||||
maxBytes = maxGSOBatchBytes
|
|
||||||
}
|
|
||||||
u.gsoMaxBytes = maxBytes
|
|
||||||
|
|
||||||
timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
|
|
||||||
if timeout < 0 {
|
|
||||||
timeout = 0
|
|
||||||
}
|
|
||||||
u.gsoFlushTimeout = timeout
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) setGROBufferSize(size int) {
|
|
||||||
if size < MTU {
|
|
||||||
size = defaultGROReadBufferSize
|
|
||||||
}
|
|
||||||
if size > maxGSOBatchBytes {
|
|
||||||
size = maxGSOBatchBytes
|
|
||||||
}
|
|
||||||
u.groBufSize = size
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) disableGSO() {
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
u.enableGSO = false
|
|
||||||
_ = u.flushGSOlocked()
|
|
||||||
u.gsoBuf = nil
|
|
||||||
u.gsoSegments = 0
|
|
||||||
u.gsoSegSize = 0
|
|
||||||
u.stopGSOTimerLocked()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||||
@@ -654,239 +305,7 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) queueGSOPacket(b []byte, addr netip.AddrPort) error {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
if !u.enableGSO || !addr.IsValid() || len(b) > u.gsoMaxBytes {
|
|
||||||
if err := u.flushGSOlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return errGSOFallback
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoSegments == 0 {
|
|
||||||
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
|
||||||
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
|
||||||
}
|
|
||||||
u.gsoAddr = addr
|
|
||||||
u.gsoSegSize = len(b)
|
|
||||||
} else if addr != u.gsoAddr || len(b) != u.gsoSegSize {
|
|
||||||
if err := u.flushGSOlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
|
||||||
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
|
||||||
}
|
|
||||||
u.gsoAddr = addr
|
|
||||||
u.gsoSegSize = len(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(u.gsoBuf)+len(b) > u.gsoMaxBytes {
|
|
||||||
if err := u.flushGSOlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
|
||||||
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
|
||||||
}
|
|
||||||
u.gsoAddr = addr
|
|
||||||
u.gsoSegSize = len(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
u.gsoBuf = append(u.gsoBuf, b...)
|
|
||||||
u.gsoSegments++
|
|
||||||
|
|
||||||
if u.gsoSegments >= u.gsoMaxSegments || u.gsoFlushTimeout <= 0 {
|
|
||||||
return u.flushGSOlocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
u.scheduleGSOFlushLocked()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) flushGSOlocked() error {
|
|
||||||
if u.gsoSegments == 0 {
|
|
||||||
u.stopGSOTimerLocked()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := append([]byte(nil), u.gsoBuf...)
|
|
||||||
addr := u.gsoAddr
|
|
||||||
segSize := u.gsoSegSize
|
|
||||||
|
|
||||||
u.gsoBuf = u.gsoBuf[:0]
|
|
||||||
u.gsoSegments = 0
|
|
||||||
u.gsoSegSize = 0
|
|
||||||
u.stopGSOTimerLocked()
|
|
||||||
|
|
||||||
if segSize <= 0 {
|
|
||||||
return errGSOFallback
|
|
||||||
}
|
|
||||||
|
|
||||||
err := u.sendSegmented(payload, addr, segSize)
|
|
||||||
if errors.Is(err, errGSODisabled) {
|
|
||||||
u.l.WithField("addr", addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
|
|
||||||
u.enableGSO = false
|
|
||||||
return u.sendSegmentsIndividually(payload, addr, segSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int) error {
|
|
||||||
if len(payload) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
control := make([]byte, unix.CmsgSpace(2))
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
hdr.Level = unix.SOL_UDP
|
|
||||||
hdr.Type = unix.UDP_SEGMENT
|
|
||||||
setCmsgLen(hdr, unix.CmsgLen(2))
|
|
||||||
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
|
|
||||||
|
|
||||||
var sa unix.Sockaddr
|
|
||||||
if addr.Addr().Is4() {
|
|
||||||
var sa4 unix.SockaddrInet4
|
|
||||||
sa4.Port = int(addr.Port())
|
|
||||||
sa4.Addr = addr.Addr().As4()
|
|
||||||
sa = &sa4
|
|
||||||
} else {
|
|
||||||
var sa6 unix.SockaddrInet6
|
|
||||||
sa6.Port = int(addr.Port())
|
|
||||||
sa6.Addr = addr.Addr().As16()
|
|
||||||
sa = &sa6
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := unix.SendmsgN(u.sysFd, payload, control, sa, 0); err != nil {
|
|
||||||
if errno, ok := err.(syscall.Errno); ok && (errno == unix.EINVAL || errno == unix.ENOTSUP || errno == unix.EOPNOTSUPP) {
|
|
||||||
return errGSODisabled
|
|
||||||
}
|
|
||||||
return &net.OpError{Op: "sendmsg", Err: err}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) sendSegmentsIndividually(buf []byte, addr netip.AddrPort, segSize int) error {
|
|
||||||
if segSize <= 0 {
|
|
||||||
return errGSOFallback
|
|
||||||
}
|
|
||||||
|
|
||||||
for offset := 0; offset < len(buf); offset += segSize {
|
|
||||||
end := offset + segSize
|
|
||||||
if end > len(buf) {
|
|
||||||
end = len(buf)
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
if u.isV4 {
|
|
||||||
err = u.writeTo4(buf[offset:end], addr)
|
|
||||||
} else {
|
|
||||||
err = u.writeTo6(buf[offset:end], addr)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) scheduleGSOFlushLocked() {
|
|
||||||
if u.gsoTimer == nil {
|
|
||||||
u.gsoTimer = time.AfterFunc(u.gsoFlushTimeout, u.gsoFlushTimer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
u.gsoTimer.Reset(u.gsoFlushTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) stopGSOTimerLocked() {
|
|
||||||
if u.gsoTimer != nil {
|
|
||||||
u.gsoTimer.Stop()
|
|
||||||
u.gsoTimer = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) gsoFlushTimer() {
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
_ = u.flushGSOlocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseGROControl(control []byte) (int, int) {
|
|
||||||
if len(control) == 0 {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
cmsgs, err := unix.ParseSocketControlMessage(control)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cmsgs {
|
|
||||||
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
|
||||||
segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
|
|
||||||
segCount := 0
|
|
||||||
if len(c.Data) >= 4 {
|
|
||||||
segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
|
|
||||||
}
|
|
||||||
return segSize, segCount
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
|
|
||||||
if segSize <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for offset := 0; offset < len(payload); offset += segSize {
|
|
||||||
end := offset + segSize
|
|
||||||
if end > len(payload) {
|
|
||||||
end = len(payload)
|
|
||||||
}
|
|
||||||
segment := make([]byte, end-offset)
|
|
||||||
copy(segment, payload[offset:end])
|
|
||||||
r(addr, segment)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeGROSegSize(segSize, segCount, total int) int {
|
|
||||||
if segSize <= 0 || total <= 0 {
|
|
||||||
return segSize
|
|
||||||
}
|
|
||||||
|
|
||||||
if segSize > total && segCount > 0 {
|
|
||||||
segSize = total / segCount
|
|
||||||
if segSize == 0 {
|
|
||||||
segSize = total
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if segCount <= 1 && segSize > 0 && total > segSize {
|
|
||||||
calculated := total / segSize
|
|
||||||
if calculated <= 1 {
|
|
||||||
calculated = (total + segSize - 1) / segSize
|
|
||||||
}
|
|
||||||
if calculated > 1 {
|
|
||||||
segCount = calculated
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if segSize > MTU {
|
|
||||||
return MTU
|
|
||||||
}
|
|
||||||
|
|
||||||
return segSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) Close() error {
|
func (u *StdConn) Close() error {
|
||||||
u.disableGSO()
|
|
||||||
return syscall.Close(u.sysFd)
|
return syscall.Close(u.sysFd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -30,16 +30,13 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
if bufSize <= 0 {
|
|
||||||
bufSize = MTU
|
|
||||||
}
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, bufSize)
|
buffers[i] = make([]byte, MTU)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
@@ -55,35 +52,3 @@ func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte
|
|||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|
||||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
|
||||||
if len(buf) == 0 {
|
|
||||||
msg.Hdr.Control = nil
|
|
||||||
msg.Hdr.Controllen = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msg.Hdr.Control = &buf[0]
|
|
||||||
msg.Hdr.Controllen = uint32(len(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRawMessageControlLen(msg *rawMessage) int {
|
|
||||||
return int(msg.Hdr.Controllen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRawMessageFlags(msg *rawMessage) int {
|
|
||||||
return int(msg.Hdr.Flags)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
|
||||||
h.Len = uint32(l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setIovecSlice(iov *iovec, b []byte) {
|
|
||||||
if len(b) == 0 {
|
|
||||||
iov.Base = nil
|
|
||||||
iov.Len = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
iov.Base = &b[0]
|
|
||||||
iov.Len = uint32(len(b))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -33,16 +33,13 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
if bufSize <= 0 {
|
|
||||||
bufSize = MTU
|
|
||||||
}
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, bufSize)
|
buffers[i] = make([]byte, MTU)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
@@ -58,35 +55,3 @@ func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte
|
|||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|
||||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
|
||||||
if len(buf) == 0 {
|
|
||||||
msg.Hdr.Control = nil
|
|
||||||
msg.Hdr.Controllen = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msg.Hdr.Control = &buf[0]
|
|
||||||
msg.Hdr.Controllen = uint64(len(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRawMessageControlLen(msg *rawMessage) int {
|
|
||||||
return int(msg.Hdr.Controllen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRawMessageFlags(msg *rawMessage) int {
|
|
||||||
return int(msg.Hdr.Flags)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
|
||||||
h.Len = uint64(l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setIovecSlice(iov *iovec, b []byte) {
|
|
||||||
if len(b) == 0 {
|
|
||||||
iov.Base = nil
|
|
||||||
iov.Len = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
iov.Base = &b[0]
|
|
||||||
iov.Len = uint64(len(b))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -92,25 +92,6 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
|||||||
// Enable v4 for this socket
|
// Enable v4 for this socket
|
||||||
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
|
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
|
||||||
|
|
||||||
// Disable reporting of PORT_UNREACHABLE and NET_UNREACHABLE errors from the UDP socket receive call.
|
|
||||||
// These errors are returned on Windows during UDP receives based on the receipt of ICMP packets. Disable
|
|
||||||
// the UDP receive error returns with these ioctl calls.
|
|
||||||
ret := uint32(0)
|
|
||||||
flag := uint32(0)
|
|
||||||
size := uint32(unsafe.Sizeof(flag))
|
|
||||||
err = syscall.WSAIoctl(syscall.Handle(u.sock), syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ret = 0
|
|
||||||
flag = 0
|
|
||||||
size = uint32(unsafe.Sizeof(flag))
|
|
||||||
SIO_UDP_NETRESET := uint32(syscall.IOC_IN | syscall.IOC_VENDOR | 15)
|
|
||||||
err = syscall.WSAIoctl(syscall.Handle(u.sock), SIO_UDP_NETRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = u.rx.Open()
|
err = u.rx.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -134,20 +115,16 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader) error {
|
func (u *RIOConn) ListenOut(r EncReader) {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.receive(buffer)
|
n, rua, err := u.receive(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
||||||
}
|
}
|
||||||
@@ -304,17 +281,6 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
|
|||||||
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
|
||||||
sent := 0
|
|
||||||
for _, pkt := range pkts {
|
|
||||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
|
||||||
return sent, err
|
|
||||||
}
|
|
||||||
sent++
|
|
||||||
}
|
|
||||||
return sent, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
sa, err := windows.Getsockname(u.sock)
|
sa, err := windows.Getsockname(u.sock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -106,17 +106,6 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
|
||||||
sent := 0
|
|
||||||
for _, pkt := range pkts {
|
|
||||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
|
||||||
return sent, err
|
|
||||||
}
|
|
||||||
sent++
|
|
||||||
}
|
|
||||||
return sent, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *TesterConn) ListenOut(r EncReader) {
|
func (u *TesterConn) ListenOut(r EncReader) {
|
||||||
for {
|
for {
|
||||||
p, ok := <-u.RxPackets
|
p, ok := <-u.RxPackets
|
||||||
|
|||||||
Reference in New Issue
Block a user