Compare commits

...

25 Commits

Author SHA1 Message Date
Ryan
9253f36a3c tweak defaults and turn on gsogro on linux by default 2025-11-06 13:34:58 -05:00
Ryan
c9a695c2bf try with sendmmsg merged back 2025-11-06 10:56:53 -05:00
Ryan
2c6f81c224 config tweaks for batching 2025-11-06 10:01:20 -05:00
Ryan
ad37749c5e add batching of packets 2025-11-06 09:42:13 -05:00
Ryan
a0f8cb2098 works properly 2025-11-05 22:09:06 -05:00
Ryan
d18d1aea67 first 2025-11-05 20:34:02 -05:00
Ryan
f5ff534671 make it work with dnclient 2025-11-05 19:25:32 -05:00
Nate Brown
2ea8a72d5c dunno 2025-10-05 23:23:30 -05:00
Nate Brown
663232e1fc Testing the concept 2025-10-05 23:23:10 -05:00
Nate Brown
2f48529e8b Cleanup and note more work 2025-10-05 23:23:08 -05:00
Nate Brown
f3e1ad64cd Try the timeout 2025-10-05 23:22:29 -05:00
Nate Brown
1d8112a329 Revert "More playing" way too much garbage emitted
This reverts commit fa098c551a.
2025-10-05 23:22:29 -05:00
Nate Brown
31eea0cc94 More playing 2025-10-05 23:22:29 -05:00
Nate Brown
dbba4a4c77 Playing 2025-10-05 23:22:29 -05:00
Nate Brown
194fde45da non-blocking io for linux 2025-10-05 23:22:27 -05:00
Nate Brown
f46b83f2c4 Remove more os.Exit calls and give a more reliable wait for stop function 2025-10-05 23:20:43 -05:00
Nate Brown
fb7f0c3657 Use x/net/route to manage routes directly (#1488) 2025-10-03 10:59:53 -05:00
sl274
b1f53d8d25 Support IPv6 tunneling in FreeBSD (#1399)
Recent merge of cert-v2 support introduced the ability to tunnel IPv6. However, FreeBSD's IPv6 tunneling does not work for 2 reasons:
* The ifconfig commands did not work for IPv6 addresses
* The tunnel device was not configured for link-layer mode, so it only supported IPv4

This PR improves FreeBSD tunneling support in 3 ways:
* Use ioctl instead of exec'ing ifconfig to configure the interface, with additional logic to support IPv6
* Configure the tunnel in link-layer mode, allowing IPv6 traffic
* Use readv() and writev() to communicate with the tunnel device, to avoid the need to copy the packet buffer
2025-10-02 21:54:30 -05:00
Jack Doan
8824eeaea2 helper functions to more correctly marshal curve 25519 public keys (#1481) 2025-10-02 13:56:41 -05:00
dependabot[bot]
071589f7c7 Bump actions/setup-go from 5 to 6 (#1469)
* Bump actions/setup-go from 5 to 6

Bumps [actions/setup-go](https://github.com/actions/setup-go) from 5 to 6.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Hardcode the last one to go v1.25

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-10-02 00:05:12 -05:00
Jack Doan
f1e992f6dd don't require a detailsVpnAddr in a HostUpdateNotification (#1472)
* don't require a detailsVpnAddr in a HostUpdateNotification

* don't send our own addr on HostUpdateNotification for v2
2025-09-29 13:43:12 -05:00
Jack Doan
1ea5f776d7 update to go 1.25, use the cool new ECDSA key marshalling functions (#1483)
* update to go 1.25, use the cool new ECDSA key marshalling functions

* bonk the runners

* actually bump go.mod

* bump golangci-lint
2025-09-29 13:02:25 -05:00
Henry Graham
4cdeb284ef Set CKA_VALUE_LEN attribute in DeriveNoise (#1482) 2025-09-25 13:24:52 -05:00
Jack Doan
5cccd39465 update RemoteList.vpnAddrs when we complete a handshake (#1467) 2025-09-10 09:44:25 -05:00
Jack Doan
8196c22b5a store lighthouses as a slice (#1473)
* store lighthouses as a slice. If you have fewer than 16 lighthouses (and fewer than 16 vpnaddrs on a host, I guess), it's faster
2025-09-10 09:43:25 -05:00
42 changed files with 2173 additions and 329 deletions

View File

@@ -16,9 +16,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Install goimports - name: Install goimports

View File

@@ -12,9 +12,9 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -35,9 +35,9 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -68,9 +68,9 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Import certificates - name: Import certificates

View File

@@ -22,9 +22,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version-file: 'go.mod' go-version: '1.25'
check-latest: true check-latest: true
- name: add hashicorp source - name: add hashicorp source

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: build - name: build

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -34,7 +34,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v8 uses: golangci/golangci-lint-action@v8
with: with:
version: v2.1 version: v2.5
- name: Test - name: Test
run: make test run: make test
@@ -58,9 +58,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -79,9 +79,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.22' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -100,9 +100,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v6
with: with:
go-version: '1.24' go-version: '1.25'
check-latest: true check-latest: true
- name: Build nebula - name: Build nebula
@@ -117,7 +117,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v8 uses: golangci/golangci-lint-action@v8
with: with:
version: v2.1 version: v2.5
- name: Test - name: Test
run: make test run: make test

View File

@@ -5,6 +5,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
type Bits struct { type Bits struct {
length uint64 length uint64
current uint64 current uint64
@@ -43,7 +44,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
} }
// Not within the window // Not within the window
l.Debugf("rejected a packet (top) %d %d\n", b.current, i) l.Error("rejected a packet (top) %d %d\n", b.current, i)
return false return false
} }

View File

@@ -58,6 +58,9 @@ type Certificate interface {
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations. // PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
PublicKey() []byte PublicKey() []byte
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
MarshalPublicKeyPEM() []byte
// Curve identifies which curve was used for the PublicKey and Signature. // Curve identifies which curve was used for the PublicKey and Signature.
Curve() Curve Curve() Curve

View File

@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey return c.details.publicKey
} }
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV1) Signature() []byte { func (c *certificateV1) Signature() []byte {
return c.signature return c.signature
} }
@@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
case Curve_CURVE25519: case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature) return ed25519.Verify(key, b, c.signature)
case Curve_P256: case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key) pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} if err != nil {
return false
}
hashed := sha256.Sum256(b) hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default: default:

View File

@@ -1,6 +1,7 @@
package cert package cert
import ( import (
"crypto/ed25519"
"fmt" "fmt"
"net/netip" "net/netip"
"testing" "testing"
@@ -13,6 +14,7 @@ import (
) )
func TestCertificateV1_Marshal(t *testing.T) { func TestCertificateV1_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second) before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab") pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups()) assert.Equal(t, nc.Groups(), nc2.Groups())
} }
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV1_Expired(t *testing.T) { func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{ nc := certificateV1{
details: detailsV1{ details: detailsV1{

View File

@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
return c.publicKey return c.publicKey
} }
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV2) Signature() []byte { func (c *certificateV2) Signature() []byte {
return c.signature return c.signature
} }
@@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
case Curve_CURVE25519: case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature) return ed25519.Verify(key, b, c.signature)
case Curve_P256: case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key) pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} if err != nil {
return false
}
hashed := sha256.Sum256(b) hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default: default:

View File

@@ -15,6 +15,7 @@ import (
) )
func TestCertificateV2_Marshal(t *testing.T) { func TestCertificateV2_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second) before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab") pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups()) assert.Equal(t, nc.Groups(), nc2.Groups())
} }
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
publicKey: pubKey,
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.curve = Curve_P256
nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV2_Expired(t *testing.T) { func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{ nc := certificateV2{
details: detailsV2{ details: detailsV2{

View File

@@ -1,25 +1,34 @@
package cert package cert
import ( import (
"encoding/hex"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"time"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
const ( const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE" CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2" CertificateV2Banner = "NEBULA CERTIFICATE V2"
)
const ( //key-agreement-key banners
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
)
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
const ( //signing key banners
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
) )
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed // UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -51,6 +60,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
} }
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
} else {
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
}
}
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve { switch curve {
case Curve_CURVE25519: case Curve_CURVE25519:
@@ -62,6 +81,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
} }
} }
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b) k, r := pem.Decode(b)
if k == nil { if k == nil {
@@ -73,7 +105,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner: case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32 expectedLen = 32
curve = Curve_CURVE25519 curve = Curve_CURVE25519
case P256PublicKeyBanner: case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
// Uncompressed // Uncompressed
expectedLen = 65 expectedLen = 65
curve = Curve_P256 curve = Curve_P256
@@ -108,6 +140,101 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
} }
} }
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
issuerBytes, err := func() ([]byte, error) {
issuer := c.Issuer()
if issuer == "" {
return nil, nil
}
decoded, err := hex.DecodeString(issuer)
if err != nil {
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
return decoded, nil
}()
if err != nil {
return nil, rest, err
}
pubKey := c.PublicKey()
if pubKey != nil {
pubKey = append([]byte(nil), pubKey...)
}
sig := c.Signature()
if sig != nil {
sig = append([]byte(nil), sig...)
}
return &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: pubKey,
IsCA: c.IsCA(),
Issuer: issuerBytes,
Curve: c.Curve(),
},
Signature: sig,
cert: c,
}, rest, nil
}
// IssuerString returns the issuer in hex format for compatibility
func (n *NebulaCertificate) IssuerString() string {
if n.Details.Issuer == nil {
return ""
}
return hex.EncodeToString(n.Details.Issuer)
}
// Certificate returns the underlying certificate (read-only)
func (n *NebulaCertificate) Certificate() Certificate {
return n.cert
}
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non // UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
// consumed data or an error on failure // consumed data or an error on failure
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {

View File

@@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
} }
func TestUnmarshalPublicKeyFromPEM(t *testing.T) { func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY----- -----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
} }
func TestUnmarshalX25519PublicKey(t *testing.T) { func TestUnmarshalX25519PublicKey(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY----- -----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY----- -----END NEBULA P256 PUBLIC KEY-----
`)
oldPubP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`) `)
shortKey := []byte(`# A short key shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY----- -----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`) -END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case // Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32) assert.Len(t, k, 32)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Success test case // Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65) assert.Len(t, k, 65)

View File

@@ -7,7 +7,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"math/big"
"net/netip" "net/netip"
"time" "time"
) )
@@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
} }
return t.SignWith(signer, curve, sp) return t.SignWith(signer, curve, sp)
case Curve_P256: case Curve_P256:
pk := &ecdsa.PrivateKey{ pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
PublicKey: ecdsa.PublicKey{ if err != nil {
Curve: elliptic.P256(), return nil, err
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
} }
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) { sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA // We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1 // - https://pkg.go.dev/crypto/ecdsa#SignASN1

View File

@@ -65,8 +65,16 @@ func main() {
} }
if !*configTest { if !*configTest {
ctrl.Start() wait, err := ctrl.Start()
ctrl.ShutdownBlock() if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
} }
os.Exit(0) os.Exit(0)

View File

@@ -3,6 +3,9 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"log"
"net/http"
_ "net/http/pprof"
"os" "os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -58,10 +61,22 @@ func main() {
os.Exit(1) os.Exit(1)
} }
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest { if !*configTest {
ctrl.Start() wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l) notifyReady(l)
ctrl.ShutdownBlock() wait()
l.Info("Goodbye")
} }
os.Exit(0) os.Exit(0)

View File

@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
addrMap: map[netip.Addr]*RemoteList{}, addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan netip.Addr, 10), queryChan: make(chan netip.Addr, 10),
} }
lighthouses := map[netip.Addr]struct{}{} lighthouses := []netip.Addr{}
staticList := map[netip.Addr]struct{}{} staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses) lh.lighthouses.Store(&lighthouses)
@@ -446,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey return d.publicKey
} }
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte { func (d *dummyCert) Signature() []byte {
return d.signature return d.signature
} }

View File

@@ -13,7 +13,9 @@ import (
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
) )
const ReplayWindow = 1024 // TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
// 4092 should be sufficient for 5Gbps
const ReplayWindow = 4096
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState

View File

@@ -2,9 +2,11 @@ package nebula
import ( import (
"context" "context"
"errors"
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -13,6 +15,16 @@ import (
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
) )
type RunState int
const (
Stopped RunState = 0 // The control has yet to be started
Started RunState = 1 // The control has been started
Stopping RunState = 2 // The control is stopping
)
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,6 +38,9 @@ type controlHostLister interface {
} }
type Control struct { type Control struct {
stateLock sync.Mutex
state RunState
f *Interface f *Interface
l *logrus.Logger l *logrus.Logger
ctx context.Context ctx context.Context
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
} }
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() // Start actually runs nebula, this is a nonblocking call.
func (c *Control) Start() { // The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Activate the interface // Activate the interface
c.f.activate() err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created. // Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil { if c.sshStart != nil {
@@ -72,15 +98,33 @@ func (c *Control) Start() {
} }
// Start reading packets. // Start reading packets.
c.f.run() c.state = Started
c.stateLock.Unlock()
return c.f.run(c.ctx)
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
} }
func (c *Control) Context() context.Context { func (c *Control) Context() context.Context {
return c.ctx return c.ctx
} }
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete // Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() { func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from // Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down. // being created while we're shutting them all down.
c.cancel() c.cancel()
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil { if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed") c.l.WithError(err).Error("Close interface failed")
} }
c.l.Info("Goodbye") c.state = Stopped
} }
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

View File

@@ -132,6 +132,13 @@ listen:
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
# default is 64, does not support reload # default is 64, does not support reload
#batch: 64 #batch: 64
# Control batching between UDP and TUN pipelines
#batch:
# inbound_size: 32 # packets to queue from UDP before handing to workers
# outbound_size: 32 # packets to queue from TUN before handing to workers
# flush_interval: 50us # flush partially filled batches after this duration
# max_outstanding: 1028 # batches buffered per routine on each channel
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel # Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default) # Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default)
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide # Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide

4
go.mod
View File

@@ -1,8 +1,6 @@
module github.com/slackhq/nebula module github.com/slackhq/nebula
go 1.23.0 go 1.25
toolchain go1.24.1
require ( require (
dario.cat/mergo v1.0.2 dario.cat/mergo v1.0.2

View File

@@ -459,7 +459,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.connectionManager.AddTrafficWatch(hostinfo) f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.remotes.ResetBlockedRemotes() hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
return return
} }
@@ -667,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
} }
hostinfo.remotes.ResetBlockedRemotes() hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
f.metricHandshakes.Update(duration) f.metricHandshakes.Update(duration)
return false return false

View File

@@ -11,19 +11,19 @@ import (
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, queue func(netip.AddrPort, int), q int, localCache firewall.ConntrackCache) bool {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
} }
return return false
} }
// Ignore local broadcast packets // Ignore local broadcast packets
if f.dropLocalBroadcast { if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
return return false
} }
} }
@@ -40,12 +40,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
// Otherwise, drop. On linux, we should never see these packets - Linux // Otherwise, drop. On linux, we should never see these packets - Linux
// routes packets from the nebula addr to the nebula addr through the loopback device. // routes packets from the nebula addr to the nebula addr through the loopback device.
return return false
} }
// Ignore multicast packets // Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
return return false
} }
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
@@ -59,18 +59,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
} }
return return false
} }
if !ready { if !ready {
return return false
} }
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil { if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) return f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, queue, q)
}
} else {
f.rejectInside(packet, out, q) f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l). hostinfo.logger(f.l).
@@ -78,7 +78,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
WithField("reason", dropReason). WithField("reason", dropReason).
Debugln("dropping outbound packet") Debugln("dropping outbound packet")
} }
} return false
} }
func (f *Interface) rejectInside(packet []byte, out []byte, q int) { func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
@@ -117,7 +117,7 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
return return
} }
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) _ = f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, nil, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
@@ -228,7 +228,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
return return
} }
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) _ = f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
} }
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
@@ -258,12 +258,12 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) _ = f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
} }
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) _ = f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, nil, 0)
} }
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
@@ -331,9 +331,12 @@ func (f *Interface) SendVia(via *HostInfo,
f.connectionManager.RelayUsed(relay.LocalIndex) f.connectionManager.RelayUsed(relay.LocalIndex)
} }
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { // sendNoMetrics encrypts and writes/queues an outbound packet. It returns true
// when the payload has been handed to a caller-provided queue (meaning the
// caller is responsible for flushing it later).
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, queue func(netip.AddrPort, int), q int) bool {
if ci.eKey == nil { if ci.eKey == nil {
return return false
} }
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out fullOut := out
@@ -380,22 +383,28 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
WithField("udpAddr", remote).WithField("counter", c). WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c). WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet") Error("Failed to encrypt outgoing packet")
return return false
} }
if remote.IsValid() { dest := remote
err = f.writers[q].WriteTo(out, remote) if !dest.IsValid() {
dest = hostinfo.remote
}
if dest.IsValid() {
if queue != nil {
queue(dest, len(out))
return true
}
err = f.writers[q].WriteTo(out, dest)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err). hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet") WithField("udpAddr", dest).Error("Failed to write outgoing packet")
} }
} else if hostinfo.remote.IsValid() { return false
err = f.writers[q].WriteTo(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
} }
} else {
// Try to send via a relay // Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() { for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
@@ -407,5 +416,6 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
break break
} }
}
return false
} }

View File

@@ -6,8 +6,8 @@ import (
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
"os"
"runtime" "runtime"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -18,10 +18,22 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
const mtu = 9001 const (
mtu = 9001
inboundBatchSizeDefault = 128
outboundBatchSizeDefault = 64
batchFlushIntervalDefault = 12 * time.Microsecond
maxOutstandingBatchesDefault = 8
sendBatchSizeDefault = 64
maxPendingPacketsDefault = 32
maxPendingBytesDefault = 64 * 1024
maxSendBufPerRoutineDefault = 16
)
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
@@ -47,9 +59,20 @@ type InterfaceConfig struct {
reQueryWait time.Duration reQueryWait time.Duration
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
BatchConfig BatchConfig
l *logrus.Logger l *logrus.Logger
} }
type BatchConfig struct {
InboundBatchSize int
OutboundBatchSize int
FlushInterval time.Duration
MaxOutstandingPerChan int
MaxPendingPackets int
MaxPendingBytes int
MaxSendBuffersPerChan int
}
type Interface struct { type Interface struct {
hostMap *HostMap hostMap *HostMap
outside udp.Conn outside udp.Conn
@@ -87,12 +110,165 @@ type Interface struct {
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []io.ReadWriteCloser
wg sync.WaitGroup
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger l *logrus.Logger
inPool sync.Pool
inbound []chan *packetBatch
outPool sync.Pool
outbound []chan *outboundBatch
packetBatchPool sync.Pool
outboundBatchPool sync.Pool
sendPool sync.Pool
sendBufCache [][]*[]byte
sendBatchSize int
inboundBatchSize int
outboundBatchSize int
batchFlushInterval time.Duration
maxOutstandingPerChan int
maxPendingPackets int
maxPendingBytes int
maxSendBufPerRoutine int
}
type outboundSend struct {
buf *[]byte
length int
addr netip.AddrPort
}
type packetBatch struct {
packets []*packet.Packet
}
func newPacketBatch(capacity int) *packetBatch {
return &packetBatch{
packets: make([]*packet.Packet, 0, capacity),
}
}
func (b *packetBatch) add(p *packet.Packet) {
b.packets = append(b.packets, p)
}
func (b *packetBatch) reset() {
for i := range b.packets {
b.packets[i] = nil
}
b.packets = b.packets[:0]
}
func (f *Interface) getPacketBatch() *packetBatch {
if v := f.packetBatchPool.Get(); v != nil {
b := v.(*packetBatch)
b.reset()
return b
}
return newPacketBatch(f.inboundBatchSize)
}
func (f *Interface) releasePacketBatch(b *packetBatch) {
b.reset()
f.packetBatchPool.Put(b)
}
type outboundBatch struct {
payloads []*[]byte
}
func newOutboundBatch(capacity int) *outboundBatch {
return &outboundBatch{payloads: make([]*[]byte, 0, capacity)}
}
func (b *outboundBatch) add(buf *[]byte) {
b.payloads = append(b.payloads, buf)
}
func (b *outboundBatch) reset() {
for i := range b.payloads {
b.payloads[i] = nil
}
b.payloads = b.payloads[:0]
}
func (f *Interface) getOutboundBatch() *outboundBatch {
if v := f.outboundBatchPool.Get(); v != nil {
b := v.(*outboundBatch)
b.reset()
return b
}
return newOutboundBatch(f.outboundBatchSize)
}
func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
b.reset()
f.outboundBatchPool.Put(b)
}
func (f *Interface) getSendBuffer(q int) *[]byte {
cache := f.sendBufCache[q]
if n := len(cache); n > 0 {
buf := cache[n-1]
f.sendBufCache[q] = cache[:n-1]
*buf = (*buf)[:0]
return buf
}
if v := f.sendPool.Get(); v != nil {
buf := v.(*[]byte)
*buf = (*buf)[:0]
return buf
}
b := make([]byte, mtu)
return &b
}
func (f *Interface) releaseSendBuffer(q int, buf *[]byte) {
if buf == nil {
return
}
*buf = (*buf)[:0]
cache := f.sendBufCache[q]
if len(cache) < f.maxSendBufPerRoutine {
f.sendBufCache[q] = append(cache, buf)
return
}
f.sendPool.Put(buf)
}
func (f *Interface) flushSendQueue(q int, pending *[]outboundSend, pendingBytes *int) {
if len(*pending) == 0 {
return
}
batch := make([]udp.BatchPacket, len(*pending))
for i, entry := range *pending {
batch[i] = udp.BatchPacket{
Payload: (*entry.buf)[:entry.length],
Addr: entry.addr,
}
}
sent, err := f.writers[q].WriteBatch(batch)
if err != nil {
f.l.WithError(err).WithField("sent", sent).Error("Failed to batch send packets")
}
for _, entry := range *pending {
f.releaseSendBuffer(q, entry.buf)
}
*pending = (*pending)[:0]
if pendingBytes != nil {
*pendingBytes = 0
}
} }
type EncWriter interface { type EncWriter interface {
@@ -162,6 +338,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
} }
cs := c.pki.getCertState() cs := c.pki.getCertState()
bc := c.BatchConfig
if bc.InboundBatchSize <= 0 {
bc.InboundBatchSize = inboundBatchSizeDefault
}
if bc.OutboundBatchSize <= 0 {
bc.OutboundBatchSize = outboundBatchSizeDefault
}
if bc.FlushInterval <= 0 {
bc.FlushInterval = batchFlushIntervalDefault
}
if bc.MaxOutstandingPerChan <= 0 {
bc.MaxOutstandingPerChan = maxOutstandingBatchesDefault
}
if bc.MaxPendingPackets <= 0 {
bc.MaxPendingPackets = maxPendingPacketsDefault
}
if bc.MaxPendingBytes <= 0 {
bc.MaxPendingBytes = maxPendingBytesDefault
}
if bc.MaxSendBuffersPerChan <= 0 {
bc.MaxSendBuffersPerChan = maxSendBufPerRoutineDefault
}
ifce := &Interface{ ifce := &Interface{
pki: c.pki, pki: c.pki,
hostMap: c.HostMap, hostMap: c.HostMap,
@@ -194,9 +393,49 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
}, },
inbound: make([]chan *packetBatch, c.routines),
outbound: make([]chan *outboundBatch, c.routines),
l: c.l, l: c.l,
inboundBatchSize: bc.InboundBatchSize,
outboundBatchSize: bc.OutboundBatchSize,
batchFlushInterval: bc.FlushInterval,
maxOutstandingPerChan: bc.MaxOutstandingPerChan,
maxPendingPackets: bc.MaxPendingPackets,
maxPendingBytes: bc.MaxPendingBytes,
maxSendBufPerRoutine: bc.MaxSendBuffersPerChan,
sendBatchSize: bc.OutboundBatchSize,
} }
for i := 0; i < c.routines; i++ {
ifce.inbound[i] = make(chan *packetBatch, ifce.maxOutstandingPerChan)
ifce.outbound[i] = make(chan *outboundBatch, ifce.maxOutstandingPerChan)
}
ifce.inPool = sync.Pool{New: func() any {
return packet.New()
}}
ifce.outPool = sync.Pool{New: func() any {
t := make([]byte, mtu)
return &t
}}
ifce.packetBatchPool = sync.Pool{New: func() any {
return newPacketBatch(ifce.inboundBatchSize)
}}
ifce.outboundBatchPool = sync.Pool{New: func() any {
return newOutboundBatch(ifce.outboundBatchSize)
}}
ifce.sendPool = sync.Pool{New: func() any {
buf := make([]byte, mtu)
return &buf
}}
ifce.sendBufCache = make([][]*[]byte, c.routines)
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -209,7 +448,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any // activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However, // other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called. // the interface isn't going to process anything until run() is called.
func (f *Interface) activate() { func (f *Interface) activate() error {
// actually turn on tun dev // actually turn on tun dev
addr, err := f.outside.LocalAddr() addr, err := f.outside.LocalAddr()
@@ -230,33 +469,44 @@ func (f *Interface) activate() {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
if err != nil { if err != nil {
f.l.Fatal(err) return err
} }
} }
f.readers[i] = reader f.readers[i] = reader
} }
if err := f.inside.Activate(); err != nil { if err = f.inside.Activate(); err != nil {
f.inside.Close() f.inside.Close()
f.l.Fatal(err) return err
} }
return nil
} }
func (f *Interface) run() { func (f *Interface) run(c context.Context) (func(), error) {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
// Launch n queues to read packets from udp
f.wg.Add(1)
go f.listenOut(i) go f.listenOut(i)
}
// Launch n queues to read packets from tun dev // Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ { f.wg.Add(1)
go f.listenIn(f.readers[i], i) go f.listenIn(f.readers[i], i)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.workerIn(i, c)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.workerOut(i, c)
} }
return f.wg.Wait, nil
} }
func (f *Interface) listenOut(i int) { func (f *Interface) listenOut(i int) {
runtime.LockOSThread() runtime.LockOSThread()
var li udp.Conn var li udp.Conn
if i > 0 { if i > 0 {
li = f.writers[i] li = f.writers[i]
@@ -264,41 +514,176 @@ func (f *Interface) listenOut(i int) {
li = f.outside li = f.outside
} }
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) batch := f.getPacketBatch()
lhh := f.lightHouse.NewRequestHandler() lastFlush := time.Now()
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { flush := func(force bool) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) if len(batch.packets) == 0 {
if force {
f.releasePacketBatch(batch)
}
return
}
f.inbound[i] <- batch
batch = f.getPacketBatch()
lastFlush = time.Now()
}
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
p := f.inPool.Get().(*packet.Packet)
p.Payload = p.Payload[:mtu]
copy(p.Payload, payload)
p.Payload = p.Payload[:len(payload)]
p.Addr = fromUdpAddr
batch.add(p)
if len(batch.packets) >= f.inboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
flush(false)
}
}) })
if len(batch.packets) > 0 {
f.inbound[i] <- batch
} else {
f.releasePacketBatch(batch)
}
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close
}
f.l.Debugf("underlay reader %v is done", i)
f.wg.Done()
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread() runtime.LockOSThread()
packet := make([]byte, mtu) batch := f.getOutboundBatch()
out := make([]byte, mtu) lastFlush := time.Now()
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) flush := func(force bool) {
if len(batch.payloads) == 0 {
for { if force {
n, err := reader.Read(packet) f.releaseOutboundBatch(batch)
if err != nil { }
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return return
} }
f.l.WithError(err).Error("Error while reading outbound packet") f.outbound[i] <- batch
// This only seems to happen when something fatal happens to the fd, so exit. batch = f.getOutboundBatch()
os.Exit(2) lastFlush = time.Now()
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) for {
p := f.outPool.Get().(*[]byte)
*p = (*p)[:mtu]
n, err := reader.Read(*p)
if err != nil {
if !f.closed.Load() {
f.l.WithError(err).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
}
break
}
*p = (*p)[:n]
batch.add(p)
if len(batch.payloads) >= f.outboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
flush(false)
}
}
if len(batch.payloads) > 0 {
f.outbound[i] <- batch
} else {
f.releaseOutboundBatch(batch)
}
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) workerIn(i int, ctx context.Context) {
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket2 := &firewall.Packet{}
nb2 := make([]byte, 12, 12)
result2 := make([]byte, mtu)
h := &header.H{}
for {
select {
case batch := <-f.inbound[i]:
for _, p := range batch.packets {
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
p.Payload = p.Payload[:mtu]
f.inPool.Put(p)
}
f.releasePacketBatch(batch)
case <-ctx.Done():
f.wg.Done()
return
}
}
}
func (f *Interface) workerOut(i int, ctx context.Context) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket1 := &firewall.Packet{}
nb1 := make([]byte, 12, 12)
pending := make([]outboundSend, 0, f.sendBatchSize)
pendingBytes := 0
maxPendingPackets := f.maxPendingPackets
if maxPendingPackets <= 0 {
maxPendingPackets = f.sendBatchSize
}
maxPendingBytes := f.maxPendingBytes
if maxPendingBytes <= 0 {
maxPendingBytes = f.sendBatchSize * mtu
}
for {
select {
case batch := <-f.outbound[i]:
for _, data := range batch.payloads {
sendBuf := f.getSendBuffer(i)
buf := (*sendBuf)[:0]
queue := func(addr netip.AddrPort, length int) {
if len(pending) >= maxPendingPackets || pendingBytes+length > maxPendingBytes {
f.flushSendQueue(i, &pending, &pendingBytes)
}
pending = append(pending, outboundSend{
buf: sendBuf,
length: length,
addr: addr,
})
pendingBytes += length
if len(pending) >= f.sendBatchSize || pendingBytes >= maxPendingBytes {
f.flushSendQueue(i, &pending, &pendingBytes)
}
}
sent := f.consumeInsidePacket(*data, fwPacket1, nb1, buf, queue, i, conntrackCache.Get(f.l))
if !sent {
f.releaseSendBuffer(i, sendBuf)
}
*data = (*data)[:mtu]
f.outPool.Put(data)
}
f.releaseOutboundBatch(batch)
if len(pending) > 0 {
f.flushSendQueue(i, &pending, &pendingBytes)
}
case <-ctx.Done():
if len(pending) > 0 {
f.flushSendQueue(i, &pending, &pendingBytes)
}
f.wg.Done()
return
}
} }
} }
@@ -451,6 +836,7 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error { func (f *Interface) Close() error {
f.closed.Store(true) f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers { for _, u := range f.writers {
err := u.Close() err := u.Close()
if err != nil { if err != nil {
@@ -458,6 +844,13 @@ func (f *Interface) Close() error {
} }
} }
// Release the tun device // Release the tun readers
return f.inside.Close() for _, u := range f.readers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing tun device")
}
}
return nil
} }

View File

@@ -57,7 +57,7 @@ type LightHouse struct {
// staticList exists to avoid having a bool in each addrMap entry // staticList exists to avoid having a bool in each addrMap entry
// since static should be rare // since static should be rare
staticList atomic.Pointer[map[netip.Addr]struct{}] staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[map[netip.Addr]struct{}] lighthouses atomic.Pointer[[]netip.Addr]
interval atomic.Int64 interval atomic.Int64
updateCancel context.CancelFunc updateCancel context.CancelFunc
@@ -108,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l, l: l,
} }
lighthouses := make(map[netip.Addr]struct{}) lighthouses := make([]netip.Addr, 0)
h.lighthouses.Store(&lighthouses) h.lighthouses.Store(&lighthouses)
staticList := make(map[netip.Addr]struct{}) staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList) h.staticList.Store(&staticList)
@@ -144,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load() return *lh.staticList.Load()
} }
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { func (lh *LightHouse) GetLighthouses() []netip.Addr {
return *lh.lighthouses.Load() return *lh.lighthouses.Load()
} }
@@ -307,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
} }
if initial || c.HasChanged("lighthouse.hosts") { if initial || c.HasChanged("lighthouse.hosts") {
lhMap := make(map[netip.Addr]struct{}) lhList, err := lh.parseLighthouses(c)
err := lh.parseLighthouses(c, lhMap)
if err != nil { if err != nil {
return err return err
} }
lh.lighthouses.Store(&lhMap) lh.lighthouses.Store(&lhList)
if !initial { if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed") lh.l.Info("lighthouse.hosts has changed")
@@ -347,36 +346,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
lhs := c.GetStringSlice("lighthouse.hosts", []string{}) lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 { if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
} }
out := make([]netip.Addr, len(lhs))
for i, host := range lhs { for i, host := range lhs {
addr, err := netip.ParseAddr(host) addr, err := netip.ParseAddr(host)
if err != nil { if err != nil {
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
} }
if !lh.myVpnNetworksTable.Contains(addr) { if !lh.myVpnNetworksTable.Contains(addr) {
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
} }
lhMap[addr] = struct{}{} out[i] = addr
} }
if !lh.amLighthouse && len(lhMap) == 0 { if !lh.amLighthouse && len(out) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries") lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
} }
staticList := lh.GetStaticHostList() staticList := lh.GetStaticHostList()
for lhAddr, _ := range lhMap { for i := range out {
if _, ok := staticList[lhAddr]; !ok { if _, ok := staticList[out[i]]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
} }
} }
return nil return out, nil
} }
func getStaticMapCadence(c *config.C) (time.Duration, error) { func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -487,7 +487,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock() lh.Lock()
defer lh.Unlock() defer lh.Unlock()
// Add an entry if we don't already have one // Add an entry if we don't already have one
return lh.unlockedGetRemoteList(vpnAddrs) return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
} }
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -570,7 +570,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr) am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetAddrs() { for _, addrPort := range hr.GetAddrs() {
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
continue continue
} }
switch { switch {
@@ -645,18 +645,17 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
} }
} }
//TODO lighthouse.remote_allow_ranges is almost certainly broken in a multiple-address-per-cert scenario am := NewRemoteList(allAddrs, lh.shouldAdd)
am := NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
for _, addr := range allAddrs { for _, addr := range allAddrs {
lh.addrMap[addr] = am lh.addrMap[addr] = am
} }
return am return am
} }
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow") Trace("remoteAllowList.Allow")
} }
if !allow { if !allow {
@@ -711,17 +710,24 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
} }
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
_, ok := lh.GetLighthouses()[vpnAddr] l := lh.GetLighthouses()
return ok for i := range l {
if l[i] == vpnAddr {
return true
}
}
return false
} }
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool { func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
l := lh.GetLighthouses() l := lh.GetLighthouses()
for _, a := range vpnAddr { for i := range vpnAddrs {
if _, ok := l[a]; ok { for j := range l {
if l[j] == vpnAddrs[i] {
return true return true
} }
} }
}
return false return false
} }
@@ -761,7 +767,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0 queried := 0
lighthouses := lh.GetLighthouses() lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses { for _, lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr) hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil { if hi != nil {
v = hi.ConnectionState.myCert.Version() v = hi.ConnectionState.myCert.Version()
@@ -879,7 +885,7 @@ func (lh *LightHouse) SendUpdate() {
updated := 0 updated := 0
lighthouses := lh.GetLighthouses() lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses { for _, lhVpnAddr := range lighthouses {
var v cert.Version var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr) hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil { if hi != nil {
@@ -937,7 +943,6 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4, V4AddrPorts: v4,
V6AddrPorts: v6, V6AddrPorts: v6,
RelayVpnAddrs: relays, RelayVpnAddrs: relays,
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
}, },
} }
@@ -1060,7 +1065,16 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil { if err != nil {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery") lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
} }
return return
} }
@@ -1069,9 +1083,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 { if useVersion == cert.Version1 {
if !queryVpnAddr.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := queryVpnAddr.As4() b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else { } else {
@@ -1231,16 +1242,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return return
} }
detailsVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() // not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
if err != nil { var detailsVpnAddr netip.Addr
if lhh.l.Level >= logrus.DebugLevel { var useVersion cert.Version
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostUpdateNotification") if n.Details.OldVpnAddr != 0 { //v1 always sets this field
} b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2
} else {
detailsVpnAddr = netip.Addr{}
useVersion = cert.Version2
} }
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right? //Simple check that the host sent this not someone else, if detailsVpnAddr is filled
//Simple check that the host sent this not someone else if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
} }
@@ -1254,24 +1273,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays) am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock() am.Unlock()
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck n.Type = NebulaMeta_HostUpdateNotificationAck
switch useVersion {
if useVersion == cert.Version1 { case cert.Version1:
if !fromVpnAddrs[0].Is4() { if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return return
} }
vpnAddrB := fromVpnAddrs[0].As4() vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:]) n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
} else if useVersion == cert.Version2 { case cert.Version2:
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) // do nothing, we want to send a blank message
} else { default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return return
} }
@@ -1289,7 +1308,6 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr, //It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return return
} }

31
main.go
View File

@@ -164,7 +164,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
if err != nil { if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
} }
@@ -221,6 +221,16 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
} }
batchCfg := BatchConfig{
InboundBatchSize: c.GetInt("batch.inbound_size", inboundBatchSizeDefault),
OutboundBatchSize: c.GetInt("batch.outbound_size", outboundBatchSizeDefault),
FlushInterval: c.GetDuration("batch.flush_interval", batchFlushIntervalDefault),
MaxOutstandingPerChan: c.GetInt("batch.max_outstanding", maxOutstandingBatchesDefault),
MaxPendingPackets: c.GetInt("batch.max_pending_packets", 0),
MaxPendingBytes: c.GetInt("batch.max_pending_bytes", 0),
MaxSendBuffersPerChan: c.GetInt("batch.max_send_buffers_per_routine", 0),
}
ifConfig := &InterfaceConfig{ ifConfig := &InterfaceConfig{
HostMap: hostMap, HostMap: hostMap,
Inside: tun, Inside: tun,
@@ -242,6 +252,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
relayManager: NewRelayManager(ctx, l, hostMap, c), relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy, punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
BatchConfig: batchCfg,
l: l, l: l,
} }
@@ -284,14 +295,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
return &Control{ return &Control{
ifce, f: ifce,
l, l: l,
ctx, ctx: ctx,
cancel, cancel: cancel,
sshStart, sshStart: sshStart,
statsStart, statsStart: statsStart,
dnsStart, dnsStart: dnsStart,
lightHouse.StartUpdateWorker, lighthouseStart: lightHouse.StartUpdateWorker,
connManager.Start, connectionManagerStart: connManager.Start,
}, nil }, nil
} }

View File

@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return return
} }
//l.Error("in packet ", header, packet[HeaderLen:]) //f.l.Error("in packet ", h)
if ip.IsValid() { if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) { if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
@@ -245,6 +245,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
return return
} }
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
@@ -470,7 +471,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
return false return false
} }

View File

@@ -1,6 +1,7 @@
package overlay package overlay
import ( import (
"net"
"net/netip" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -70,3 +71,13 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed return removed
} }
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
@@ -554,13 +553,3 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -10,11 +10,9 @@ import (
"io" "io"
"io/fs" "io/fs"
"net/netip" "net/netip"
"os"
"os/exec"
"strconv"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
@@ -22,12 +20,18 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
) )
const ( const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD // FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678) // For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678 FIODGNAME = 0x80106678
TUNSIFMODE = 0x8004745e
TUNSIFHEAD = 0x80047460
OSIOCAIFADDR_IN6 = 0x8088691b
IN6_IFF_NODAD = 0x0020
) )
type fiodgnameArg struct { type fiodgnameArg struct {
@@ -37,43 +41,159 @@ type fiodgnameArg struct {
} }
type ifreqRename struct { type ifreqRename struct {
Name [16]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
} }
type ifreqDestroy struct { type ifreqDestroy struct {
Name [16]byte Name [unix.IFNAMSIZ]byte
pad [16]byte pad [16]byte
} }
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
}
type ifreqMTU struct {
Name [unix.IFNAMSIZ]byte
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
type tun struct { type tun struct {
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger l *logrus.Logger
devFd int
}
io.ReadWriteCloser func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
} }
func (t *tun) Close() error { func (t *tun) Close() error {
if t.ReadWriteCloser != nil { if t.devFd >= 0 {
if err := t.ReadWriteCloser.Close(); err != nil { err := syscall.Close(t.devFd)
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil { if err != nil {
return err t.l.WithError(err).Error("Error closing device")
} }
t.devFd = -1
c := make(chan struct{})
go func() {
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s) defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()} ifreq := ifreqDestroy{Name: t.deviceBytes()}
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err }
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
}
}()
// wait up to 1 second so we start blocking at the ioctl
select {
case <-c:
case <-time.After(1 * time.Second):
}
} }
return nil return nil
@@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device // Try to open existing tun device
var file *os.File var fd int
var err error var err error
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
if deviceName != "" { if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
} }
if errors.Is(err, fs.ErrNotExist) || deviceName == "" { if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it // If the device doesn't already exist, request a new one and rename it
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0) fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
rawConn, err := file.SyscallConn() // Read the name of the interface
if err != nil { var name [16]byte
return nil, fmt.Errorf("SyscallConn: %v", err) arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
} }
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil { if ctrlErr != nil {
return nil, err return nil, err
} }
@@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
// If the name doesn't match the desired interface name, rename it now // If the name doesn't match the desired interface name, rename it now
if ifName != deviceName { if ifName != deviceName {
s, err := syscall.Socket( s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
t := &tun{ t := &tun{
ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
devFd: fd,
} }
err = t.reload(c, true) err = t.reload(c, true)
@@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
func (t *tun) addIp(cidr netip.Prefix) error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error if cidr.Addr().Is4() {
// TODO use syscalls instead of exec.Command ifr := ifreqAlias4{
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) Name: t.deviceBytes(),
t.l.Debug("command: ", cmd.String()) Addr: unix.RawSockaddrInet4{
if err = cmd.Run(); err != nil { Len: unix.SizeofSockaddrInet4,
return fmt.Errorf("failed to run 'ifconfig': %s", err) Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) if cidr.Addr().Is6() {
t.l.Debug("command: ", cmd.String()) ifr := ifreqAlias6{
if err = cmd.Run(); err != nil { Name: t.deviceBytes(),
return fmt.Errorf("failed to run 'route add': %s", err) Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
} }
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) return fmt.Errorf("unknown address type %v", cidr)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks { for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i]) err := t.addIp(t.vpnNetworks[i])
if err != nil { if err != nil {
return err return err
} }
} }
return nil
return t.addRoutes(false)
}
func (t *tun) setMTU() error {
// Set the MTU on the device
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
@@ -268,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) err := addRoute(r.Cidr, t.linkAddr)
t.l.Debug("command: ", cmd.String()) if err != nil {
if err := cmd.Run(); err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
} else { } else {
return retErr return retErr
} }
} else {
t.l.WithField("route", r).Info("Added route")
} }
} }
@@ -289,9 +484,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device) err := delRoute(r.Cidr, t.linkAddr)
t.l.Debug("command: ", cmd.String()) if err != nil {
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.WithField("route", r).Info("Removed route")
@@ -306,3 +500,144 @@ func (t *tun) deviceBytes() (o [16]byte) {
} }
return return
} }
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}

12
packet/packet.go Normal file
View File

@@ -0,0 +1,12 @@
package packet
import "net/netip"
type Packet struct {
Payload []byte
Addr netip.AddrPort
}
func New() *Packet {
return &Packet{Payload: make([]byte, 9001)}
}

View File

@@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true), pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true), pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
} }
// Set up the parameters which include the peer's public key // Set up the parameters which include the peer's public key

View File

@@ -190,7 +190,7 @@ type RemoteList struct {
// The full list of vpn addresses assigned to this host // The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr vpnAddrs []netip.Addr
// A deduplicated set of addresses. Any accessor should lock beforehand. // A deduplicated set of underlay addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays. // A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -202,7 +202,9 @@ type RemoteList struct {
cache map[netip.Addr]*cache cache map[netip.Addr]*cache
hr *hostnamesResults hr *hostnamesResults
shouldAdd func(netip.Addr) bool
// shouldAdd is a nillable function that decides if x should be added to addrs.
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake // They should not be tried again during a handshake
@@ -213,7 +215,7 @@ type RemoteList struct {
} }
// NewRemoteList creates a new empty RemoteList // NewRemoteList creates a new empty RemoteList
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
r := &RemoteList{ r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)), vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0), addrs: make([]netip.AddrPort, 0),
@@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
return c return c
} }
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
r.Lock()
r.badRemotes = nil
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
copy(r.vpnAddrs, vpnAddrs)
r.Unlock()
}
// ResetBlockedRemotes locks and clears the blocked remotes list // ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() { func (r *RemoteList) ResetBlockedRemotes() {
r.Lock() r.Lock()
@@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetAddrs() dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs { for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
if !r.unlockedIsBad(addr) { if !r.unlockedIsBad(addr) {
addrs = append(addrs, addr) addrs = append(addrs, addr)
} }

View File

@@ -44,7 +44,10 @@ type Service struct {
} }
func New(control *nebula.Control) (*Service, error) { func New(control *nebula.Control) (*Service, error) {
control.Start() wait, err := control.Start()
if err != nil {
return nil, err
}
ctx := control.Context() ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
} }
}) })
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil return &s, nil
} }

View File

@@ -16,12 +16,18 @@ type EncReader func(
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) ListenOut(r EncReader) error
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
WriteBatch(pkts []BatchPacket) (int, error)
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
Close() error Close() error
} }
type BatchPacket struct {
Payload []byte
Addr netip.AddrPort
}
type NoopConn struct{} type NoopConn struct{}
func (NoopConn) Rebind() error { func (NoopConn) Rebind() error {
@@ -30,12 +36,15 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil return netip.AddrPort{}, nil
} }
func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) ListenOut(_ EncReader) error {
return return nil
} }
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }
func (NoopConn) WriteBatch(_ []BatchPacket) (int, error) {
return 0, nil
}
func (NoopConn) ReloadConfig(_ *config.C) { func (NoopConn) ReloadConfig(_ *config.C) {
return return
} }

View File

@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
} }
} }
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
sent := 0
for _, pkt := range pkts {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
}
return sent, nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) { func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr() a := u.UDPConn.LocalAddr()
@@ -165,7 +176,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {} return func() {}
} }
func (u *StdConn) ListenOut(r EncReader) { func (u *StdConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {
@@ -174,14 +185,17 @@ func (u *StdConn) ListenOut(r EncReader) {
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return return err
} }
u.l.WithError(err).Error("unexpected udp socket receive error") u.l.WithError(err).Error("unexpected udp socket receive error")
continue
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
} }
return nil
} }
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {

View File

@@ -42,6 +42,17 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
return err return err
} }
func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) {
sent := 0
for _, pkt := range pkts {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
}
return sent, nil
}
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr() a := u.UDPConn.LocalAddr()
@@ -71,15 +82,14 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *GenericConn) ListenOut(r EncReader) { func (u *GenericConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {
// Just read one packet at a time // Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer) n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])

View File

@@ -5,10 +5,13 @@ package udp
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sync"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
@@ -17,19 +20,40 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
const (
defaultGSOMaxSegments = 128
defaultGSOFlushTimeout = 80 * time.Microsecond
defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
maxGSOBatchBytes = 0xFFFF
)
var (
errGSOFallback = errors.New("udp gso fallback")
errGSODisabled = errors.New("udp gso disabled")
)
type StdConn struct { type StdConn struct {
sysFd int sysFd int
isV4 bool isV4 bool
l *logrus.Logger l *logrus.Logger
batch int batch int
}
func maybeIPV4(ip net.IP) (net.IP, bool) { enableGRO bool
ip4 := ip.To4() enableGSO bool
if ip4 != nil {
return ip4, true gsoMu sync.Mutex
} gsoBuf []byte
return ip, false gsoAddr netip.AddrPort
gsoSegSize int
gsoSegments int
gsoMaxSegments int
gsoMaxBytes int
gsoFlushTimeout time.Duration
gsoTimer *time.Timer
groBufSize int
} }
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -55,6 +79,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
} }
} }
// Set a read timeout
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
}
var sa unix.Sockaddr var sa unix.Sockaddr
if ip.Is4() { if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port} sa4 := &unix.SockaddrInet4{Port: port}
@@ -69,7 +98,16 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err) return nil, fmt.Errorf("unable to bind to socket: %s", err)
} }
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err return &StdConn{
sysFd: fd,
isV4: ip.Is4(),
l: l,
batch: batch,
gsoMaxSegments: defaultGSOMaxSegments,
gsoMaxBytes: MTU * defaultGSOMaxSegments,
gsoFlushTimeout: defaultGSOFlushTimeout,
groBufSize: MTU,
}, err
} }
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
@@ -118,20 +156,46 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
} }
} }
func (u *StdConn) ListenOut(r EncReader) { func (u *StdConn) ListenOut(r EncReader) error {
var ip netip.Addr var (
ip netip.Addr
controls [][]byte
)
msgs, buffers, names := u.PrepareRawMessages(u.batch) bufSize := u.readBufferSize()
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
read := u.ReadMulti read := u.ReadMulti
if u.batch == 1 { if u.batch == 1 {
read = u.ReadSingle read = u.ReadSingle
} }
for { for {
desired := u.readBufferSize()
if len(buffers) == 0 || cap(buffers[0]) < desired {
msgs, buffers, names = u.PrepareRawMessages(u.batch, desired)
controls = nil
}
if u.enableGRO {
if controls == nil {
controls = make([][]byte, len(msgs))
for i := range controls {
controls[i] = make([]byte, unix.CmsgSpace(4))
}
}
for i := range msgs {
setRawMessageControl(&msgs[i], controls[i])
}
} else if controls != nil {
for i := range msgs {
setRawMessageControl(&msgs[i], nil)
}
controls = nil
}
n, err := read(msgs) n, err := read(msgs)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@@ -141,9 +205,80 @@ func (u *StdConn) ListenOut(r EncReader) {
} else { } else {
ip, _ = netip.AddrFromSlice(names[i][8:24]) ip, _ = netip.AddrFromSlice(names[i][8:24])
} }
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payload := buffers[i][:msgs[i].Len]
if u.enableGRO && u.l.IsLevelEnabled(logrus.DebugLevel) {
ctrlLen := getRawMessageControlLen(&msgs[i])
msgFlags := getRawMessageFlags(&msgs[i])
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "recv",
"payload_len": len(payload),
"ctrl_len": ctrlLen,
"msg_flags": msgFlags,
}).Debug("gro batch data")
if controls != nil && ctrlLen > 0 {
maxDump := ctrlLen
if maxDump > 16 {
maxDump = 16
}
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "control-bytes",
"control_hex": fmt.Sprintf("%x", controls[i][:maxDump]),
"datalen": ctrlLen,
}).Debug("gro control dump")
} }
} }
sawControl := false
if controls != nil {
if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segSize > 0 {
sawControl = true
if u.l.IsLevelEnabled(logrus.DebugLevel) {
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "control",
"seg_size": segSize,
"seg_count": segCount,
"payloadLen": len(payload),
}).Debug("gro control parsed")
}
segSize = normalizeGROSegSize(segSize, segCount, len(payload))
if segSize > 0 && segSize < len(payload) {
if u.emitGROSegments(r, addr, payload, segSize) {
continue
}
}
}
}
}
if u.enableGRO && len(payload) > MTU {
if !sawControl && u.l.IsLevelEnabled(logrus.DebugLevel) {
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "fallback",
"payload_len": len(payload),
}).Debug("gro control missing; splitting payload by MTU")
}
if u.emitGROSegments(r, addr, payload, MTU) {
continue
}
}
r(addr, payload)
}
}
}
func (u *StdConn) readBufferSize() int {
if u.enableGRO && u.groBufSize > MTU {
return u.groBufSize
}
return MTU
} }
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
@@ -159,6 +294,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
) )
if err != 0 { if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmsg", Err: err} return 0, &net.OpError{Op: "recvmsg", Err: err}
} }
@@ -180,6 +318,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
) )
if err != 0 { if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmmsg", Err: err} return 0, &net.OpError{Op: "recvmmsg", Err: err}
} }
@@ -188,12 +329,132 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
} }
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
if u.enableGSO && ip.IsValid() {
if err := u.queueGSOPacket(b, ip); err == nil {
return nil
} else if !errors.Is(err, errGSOFallback) {
return err
}
}
if u.isV4 { if u.isV4 {
return u.writeTo4(b, ip) return u.writeTo4(b, ip)
} }
return u.writeTo6(b, ip) return u.writeTo6(b, ip)
} }
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
if len(pkts) == 0 {
return 0, nil
}
msgs := make([]rawMessage, 0, len(pkts))
iovs := make([]iovec, 0, len(pkts))
names := make([][unix.SizeofSockaddrInet6]byte, 0, len(pkts))
sent := 0
for _, pkt := range pkts {
if len(pkt.Payload) == 0 {
sent++
continue
}
if u.enableGSO && pkt.Addr.IsValid() {
if err := u.queueGSOPacket(pkt.Payload, pkt.Addr); err == nil {
sent++
continue
} else if !errors.Is(err, errGSOFallback) {
return sent, err
}
}
if !pkt.Addr.IsValid() {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
continue
}
msgs = append(msgs, rawMessage{})
iovs = append(iovs, iovec{})
names = append(names, [unix.SizeofSockaddrInet6]byte{})
idx := len(msgs) - 1
msg := &msgs[idx]
iov := &iovs[idx]
name := &names[idx]
setIovecSlice(iov, pkt.Payload)
msg.Hdr.Iov = iov
msg.Hdr.Iovlen = 1
setRawMessageControl(msg, nil)
msg.Hdr.Flags = 0
nameLen, err := u.encodeSockaddr(name[:], pkt.Addr)
if err != nil {
return sent, err
}
msg.Hdr.Name = &name[0]
msg.Hdr.Namelen = nameLen
}
if len(msgs) == 0 {
return sent, nil
}
offset := 0
for offset < len(msgs) {
n, _, errno := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[offset])),
uintptr(len(msgs)-offset),
0,
0,
0,
)
if errno != 0 {
if errno == unix.EINTR {
continue
}
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
}
if n == 0 {
break
}
offset += int(n)
}
return sent + len(msgs), nil
}
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
if u.isV4 {
if !addr.Addr().Is4() {
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var sa unix.RawSockaddrInet4
sa.Family = unix.AF_INET
sa.Addr = addr.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet4
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
var sa unix.RawSockaddrInet6
sa.Family = unix.AF_INET6
sa.Addr = addr.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet6
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6 rsa.Family = unix.AF_INET6
@@ -221,7 +482,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
if !ip.Addr().Is4() { if !ip.Addr().Is4() {
return ErrInvalidIPv6RemoteForSocket return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
} }
var rsa unix.RawSockaddrInet4 var rsa unix.RawSockaddrInet4
@@ -294,6 +555,94 @@ func (u *StdConn) ReloadConfig(c *config.C) {
u.l.WithError(err).Error("Failed to set listen.so_mark") u.l.WithError(err).Error("Failed to set listen.so_mark")
} }
} }
u.configureGRO(c)
u.configureGSO(c)
}
func (u *StdConn) configureGRO(c *config.C) {
if c == nil {
return
}
enable := c.GetBool("listen.enable_gro", true)
if enable == u.enableGRO {
if enable {
if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 {
u.setGROBufferSize(size)
}
}
return
}
if enable {
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
u.l.WithError(err).Warn("Failed to enable UDP GRO")
return
}
u.enableGRO = true
u.setGROBufferSize(c.GetInt("listen.gro_read_buffer", defaultGROReadBufferSize))
u.l.WithField("buffer_size", u.groBufSize).Info("UDP GRO enabled")
return
}
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
u.l.WithError(err).Warn("Failed to disable UDP GRO")
}
u.enableGRO = false
u.groBufSize = MTU
}
func (u *StdConn) configureGSO(c *config.C) {
enable := c.GetBool("listen.enable_gso", true)
if !enable {
u.disableGSO()
} else {
u.enableGSO = true
}
segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
if segments < 1 {
segments = 1
}
u.gsoMaxSegments = segments
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
if maxBytes <= 0 {
maxBytes = MTU * segments
}
if maxBytes > maxGSOBatchBytes {
u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
maxBytes = maxGSOBatchBytes
}
u.gsoMaxBytes = maxBytes
timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
if timeout < 0 {
timeout = 0
}
u.gsoFlushTimeout = timeout
}
func (u *StdConn) setGROBufferSize(size int) {
if size < MTU {
size = defaultGROReadBufferSize
}
if size > maxGSOBatchBytes {
size = maxGSOBatchBytes
}
u.groBufSize = size
}
func (u *StdConn) disableGSO() {
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
u.enableGSO = false
_ = u.flushGSOlocked()
u.gsoBuf = nil
u.gsoSegments = 0
u.gsoSegSize = 0
u.stopGSOTimerLocked()
} }
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
@@ -305,7 +654,239 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
return nil return nil
} }
func (u *StdConn) queueGSOPacket(b []byte, addr netip.AddrPort) error {
if len(b) == 0 {
return nil
}
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
if !u.enableGSO || !addr.IsValid() || len(b) > u.gsoMaxBytes {
if err := u.flushGSOlocked(); err != nil {
return err
}
return errGSOFallback
}
if u.gsoSegments == 0 {
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
} else if addr != u.gsoAddr || len(b) != u.gsoSegSize {
if err := u.flushGSOlocked(); err != nil {
return err
}
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
}
if len(u.gsoBuf)+len(b) > u.gsoMaxBytes {
if err := u.flushGSOlocked(); err != nil {
return err
}
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
}
u.gsoBuf = append(u.gsoBuf, b...)
u.gsoSegments++
if u.gsoSegments >= u.gsoMaxSegments || u.gsoFlushTimeout <= 0 {
return u.flushGSOlocked()
}
u.scheduleGSOFlushLocked()
return nil
}
func (u *StdConn) flushGSOlocked() error {
if u.gsoSegments == 0 {
u.stopGSOTimerLocked()
return nil
}
payload := append([]byte(nil), u.gsoBuf...)
addr := u.gsoAddr
segSize := u.gsoSegSize
u.gsoBuf = u.gsoBuf[:0]
u.gsoSegments = 0
u.gsoSegSize = 0
u.stopGSOTimerLocked()
if segSize <= 0 {
return errGSOFallback
}
err := u.sendSegmented(payload, addr, segSize)
if errors.Is(err, errGSODisabled) {
u.l.WithField("addr", addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
u.enableGSO = false
return u.sendSegmentsIndividually(payload, addr, segSize)
}
return err
}
func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int) error {
if len(payload) == 0 {
return nil
}
control := make([]byte, unix.CmsgSpace(2))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
setCmsgLen(hdr, unix.CmsgLen(2))
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
var sa unix.Sockaddr
if addr.Addr().Is4() {
var sa4 unix.SockaddrInet4
sa4.Port = int(addr.Port())
sa4.Addr = addr.Addr().As4()
sa = &sa4
} else {
var sa6 unix.SockaddrInet6
sa6.Port = int(addr.Port())
sa6.Addr = addr.Addr().As16()
sa = &sa6
}
if _, err := unix.SendmsgN(u.sysFd, payload, control, sa, 0); err != nil {
if errno, ok := err.(syscall.Errno); ok && (errno == unix.EINVAL || errno == unix.ENOTSUP || errno == unix.EOPNOTSUPP) {
return errGSODisabled
}
return &net.OpError{Op: "sendmsg", Err: err}
}
return nil
}
func (u *StdConn) sendSegmentsIndividually(buf []byte, addr netip.AddrPort, segSize int) error {
if segSize <= 0 {
return errGSOFallback
}
for offset := 0; offset < len(buf); offset += segSize {
end := offset + segSize
if end > len(buf) {
end = len(buf)
}
var err error
if u.isV4 {
err = u.writeTo4(buf[offset:end], addr)
} else {
err = u.writeTo6(buf[offset:end], addr)
}
if err != nil {
return err
}
}
return nil
}
func (u *StdConn) scheduleGSOFlushLocked() {
if u.gsoTimer == nil {
u.gsoTimer = time.AfterFunc(u.gsoFlushTimeout, u.gsoFlushTimer)
return
}
u.gsoTimer.Reset(u.gsoFlushTimeout)
}
func (u *StdConn) stopGSOTimerLocked() {
if u.gsoTimer != nil {
u.gsoTimer.Stop()
u.gsoTimer = nil
}
}
func (u *StdConn) gsoFlushTimer() {
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
_ = u.flushGSOlocked()
}
func parseGROControl(control []byte) (int, int) {
if len(control) == 0 {
return 0, 0
}
cmsgs, err := unix.ParseSocketControlMessage(control)
if err != nil {
return 0, 0
}
for _, c := range cmsgs {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
segCount := 0
if len(c.Data) >= 4 {
segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
}
return segSize, segCount
}
}
return 0, 0
}
func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
if segSize <= 0 {
return false
}
for offset := 0; offset < len(payload); offset += segSize {
end := offset + segSize
if end > len(payload) {
end = len(payload)
}
segment := make([]byte, end-offset)
copy(segment, payload[offset:end])
r(addr, segment)
}
return true
}
func normalizeGROSegSize(segSize, segCount, total int) int {
if segSize <= 0 || total <= 0 {
return segSize
}
if segSize > total && segCount > 0 {
segSize = total / segCount
if segSize == 0 {
segSize = total
}
}
if segCount <= 1 && segSize > 0 && total > segSize {
calculated := total / segSize
if calculated <= 1 {
calculated = (total + segSize - 1) / segSize
}
if calculated > 1 {
segCount = calculated
}
}
if segSize > MTU {
return MTU
}
return segSize
}
func (u *StdConn) Close() error { func (u *StdConn) Close() error {
u.disableGSO()
return syscall.Close(u.sysFd) return syscall.Close(u.sysFd)
} }

View File

@@ -30,13 +30,16 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
if bufSize <= 0 {
bufSize = MTU
}
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, MTU) buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
@@ -52,3 +55,35 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
return msgs, buffers, names return msgs, buffers, names
} }
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint32(len(buf))
}
func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func getRawMessageFlags(msg *rawMessage) int {
return int(msg.Hdr.Flags)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint32(l)
}
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint32(len(b))
}

View File

@@ -33,13 +33,16 @@ type rawMessage struct {
Pad0 [4]byte Pad0 [4]byte
} }
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
if bufSize <= 0 {
bufSize = MTU
}
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, MTU) buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
@@ -55,3 +58,35 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
return msgs, buffers, names return msgs, buffers, names
} }
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint64(len(buf))
}
func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func getRawMessageFlags(msg *rawMessage) int {
return int(msg.Hdr.Flags)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint64(len(b))
}

View File

@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil return nil
} }
func (u *RIOConn) ListenOut(r EncReader) { func (u *RIOConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {
@@ -304,6 +304,17 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
} }
func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) {
sent := 0
for _, pkt := range pkts {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
}
return sent, nil
}
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
sa, err := windows.Getsockname(u.sock) sa, err := windows.Getsockname(u.sock)
if err != nil { if err != nil {

View File

@@ -106,6 +106,17 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil return nil
} }
func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) {
sent := 0
for _, pkt := range pkts {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
}
return sent, nil
}
func (u *TesterConn) ListenOut(r EncReader) { func (u *TesterConn) ListenOut(r EncReader) {
for { for {
p, ok := <-u.RxPackets p, ok := <-u.RxPackets