Compare commits

..

26 Commits

Author SHA1 Message Date
JackDoan
e7423d39f9 cursed 2025-11-06 09:18:33 -06:00
JackDoan
befba57366 hmmm 2025-11-05 15:38:47 -06:00
Ryan
2d128a3254 add locking for stop crash 2025-11-05 11:58:25 -05:00
Ryan
c8980d34cf fixes 2025-11-05 10:54:08 -05:00
Ryan
98f264cf14 works well 2025-11-04 19:33:52 -05:00
Ryan
aa44f4c7c9 hmmmmmm it works i guess maybe 2025-11-04 16:08:31 -05:00
Ryan Huber
419157c407 passes traffic 2025-11-04 04:50:35 +00:00
Ryan Huber
0864852d33 updated bind 2025-11-04 04:39:07 +00:00
Ryan Huber
2b5aec9a18 updated udp 2025-11-04 04:34:59 +00:00
Ryan Huber
f0665bee20 pem.go restored 2025-11-04 04:32:22 +00:00
Ryan Huber
11da0baab1 quick fix 2025-11-04 04:21:27 +00:00
Ryan Huber
608904b9dd add new files for compat layer 2025-11-04 04:10:51 +00:00
Ryan Huber
fd1c52127f first try 2025-11-04 04:00:29 +00:00
Jack Doan
01909f4715 try to make certificate addition/removal reloadable in some cases (#1468)
* try to make certificate addition/removal reloadable in some cases

* very spicy change to respond to handshakes with cert versions we cannot match with a cert that we can indeed match

* even spicier change to rehandshake if we detect our cert is lower-version than our peer, and we have a newer-version cert available

* make tryRehandshake easier to understand
2025-11-03 19:38:44 -06:00
Jack Doan
770147264d fix make bench (#1510) 2025-10-21 11:32:34 -05:00
dependabot[bot]
fa8c013b97 Bump github.com/miekg/dns from 1.1.65 to 1.1.68 (#1444)
Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.65 to 1.1.68.
- [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release)
- [Commits](https://github.com/miekg/dns/compare/v1.1.65...v1.1.68)

---
updated-dependencies:
- dependency-name: github.com/miekg/dns
  dependency-version: 1.1.68
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 16:41:51 -04:00
dependabot[bot]
2710f2af06 Bump github.com/kardianos/service from 1.2.2 to 1.2.4 (#1433)
Bumps [github.com/kardianos/service](https://github.com/kardianos/service) from 1.2.2 to 1.2.4.
- [Commits](https://github.com/kardianos/service/compare/v1.2.2...v1.2.4)

---
updated-dependencies:
- dependency-name: github.com/kardianos/service
  dependency-version: 1.2.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:58:15 -04:00
dependabot[bot]
ad6d3e6bac Bump the golang-x-dependencies group across 1 directory with 5 updates (#1409)
Bumps the golang-x-dependencies group with 3 updates in the / directory: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net) and [golang.org/x/sync](https://github.com/golang/sync).


Updates `golang.org/x/crypto` from 0.37.0 to 0.38.0
- [Commits](https://github.com/golang/crypto/compare/v0.37.0...v0.38.0)

Updates `golang.org/x/net` from 0.39.0 to 0.40.0
- [Commits](https://github.com/golang/net/compare/v0.39.0...v0.40.0)

Updates `golang.org/x/sync` from 0.13.0 to 0.14.0
- [Commits](https://github.com/golang/sync/compare/v0.13.0...v0.14.0)

Updates `golang.org/x/sys` from 0.32.0 to 0.33.0
- [Commits](https://github.com/golang/sys/compare/v0.32.0...v0.33.0)

Updates `golang.org/x/term` from 0.31.0 to 0.32.0
- [Commits](https://github.com/golang/term/compare/v0.31.0...v0.32.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.38.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-version: 0.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sync
  dependency-version: 0.14.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-version: 0.33.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-version: 0.32.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:54:38 -04:00
dependabot[bot]
2b0aa74e85 Bump github.com/prometheus/client_golang from 1.22.0 to 1.23.2 (#1470)
Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.22.0 to 1.23.2.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.22.0...v1.23.2)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-version: 1.23.2
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:16:24 -04:00
dependabot[bot]
b126d88963 Bump github.com/gaissmai/bart from 0.20.4 to 0.25.0 (#1471)
Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.20.4 to 0.25.0.
- [Release notes](https://github.com/gaissmai/bart/releases)
- [Commits](https://github.com/gaissmai/bart/compare/v0.20.4...v0.25.0)

---
updated-dependencies:
- dependency-name: github.com/gaissmai/bart
  dependency-version: 0.25.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:15:07 -04:00
Nate Brown
45c1d3eab3 Support for multi proto tun device on OpenBSD (#1495) 2025-10-08 16:56:42 -05:00
Gary Guo
634181ba66 Fix incorrect CIDR construction in hostmap (#1493)
* Fix incorrect CIDR construction in hostmap

* Introduce a regression test for incorrect hostmap CIDR
2025-10-08 11:02:36 -05:00
Nate Brown
eb89839d13 Support for multi proto tun device on NetBSD (#1492) 2025-10-07 20:17:50 -05:00
Nate Brown
fb7f0c3657 Use x/net/route to manage routes directly (#1488) 2025-10-03 10:59:53 -05:00
sl274
b1f53d8d25 Support IPv6 tunneling in FreeBSD (#1399)
Recent merge of cert-v2 support introduced the ability to tunnel IPv6. However, FreeBSD's IPv6 tunneling does not work for 2 reasons:
* The ifconfig commands did not work for IPv6 addresses
* The tunnel device was not configured for link-layer mode, so it only supported IPv4

This PR improves FreeBSD tunneling support in 3 ways:
* Use ioctl instead of exec'ing ifconfig to configure the interface, with additional logic to support IPv6
* Configure the tunnel in link-layer mode, allowing IPv6 traffic
* Use readv() and writev() to communicate with the tunnel device, to avoid the need to copy the packet buffer
2025-10-02 21:54:30 -05:00
Jack Doan
8824eeaea2 helper functions to more correctly marshal curve 25519 public keys (#1481) 2025-10-02 13:56:41 -05:00
56 changed files with 6106 additions and 522 deletions

164
batch_pipeline.go Normal file
View File

@@ -0,0 +1,164 @@
package nebula
import (
"net/netip"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
)
// batchPipelines tracks whether the inside device can operate on packet batches
// and, if so, holds the shared packet pool sized for the virtio headroom and
// payload limits advertised by the device. It also owns the fan-in/fan-out
// queues between the TUN readers, encrypt/decrypt workers, and the UDP writers.
type batchPipelines struct {
enabled bool
inside overlay.BatchCapableDevice
headroom int
payloadCap int
pool *overlay.PacketPool
batchSize int
routines int
rxQueues []chan *overlay.Packet
txQueues []chan queuedDatagram
tunQueues []chan *overlay.Packet
}
type queuedDatagram struct {
packet *overlay.Packet
addr netip.AddrPort
}
func (bp *batchPipelines) init(device overlay.Device, routines int, queueDepth int, maxSegments int) {
if device == nil || routines <= 0 {
return
}
bcap, ok := device.(overlay.BatchCapableDevice)
if !ok {
return
}
headroom := bcap.BatchHeadroom()
payload := bcap.BatchPayloadCap()
if maxSegments < 1 {
maxSegments = 1
}
requiredPayload := udp.MTU * maxSegments
if payload < requiredPayload {
payload = requiredPayload
}
batchSize := bcap.BatchSize()
if headroom <= 0 || payload <= 0 || batchSize <= 0 {
return
}
bp.enabled = true
bp.inside = bcap
bp.headroom = headroom
bp.payloadCap = payload
bp.batchSize = batchSize
bp.routines = routines
bp.pool = overlay.NewPacketPool(headroom, payload)
queueCap := batchSize * defaultBatchQueueDepthFactor
if queueDepth > 0 {
queueCap = queueDepth
}
if queueCap < batchSize {
queueCap = batchSize
}
bp.rxQueues = make([]chan *overlay.Packet, routines)
bp.txQueues = make([]chan queuedDatagram, routines)
bp.tunQueues = make([]chan *overlay.Packet, routines)
for i := 0; i < routines; i++ {
bp.rxQueues[i] = make(chan *overlay.Packet, queueCap)
bp.txQueues[i] = make(chan queuedDatagram, queueCap)
bp.tunQueues[i] = make(chan *overlay.Packet, queueCap)
}
}
func (bp *batchPipelines) Pool() *overlay.PacketPool {
if bp == nil || !bp.enabled {
return nil
}
return bp.pool
}
func (bp *batchPipelines) Enabled() bool {
return bp != nil && bp.enabled
}
func (bp *batchPipelines) batchSizeHint() int {
if bp == nil || bp.batchSize <= 0 {
return 1
}
return bp.batchSize
}
func (bp *batchPipelines) rxQueue(i int) chan *overlay.Packet {
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.rxQueues) {
return nil
}
return bp.rxQueues[i]
}
func (bp *batchPipelines) txQueue(i int) chan queuedDatagram {
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.txQueues) {
return nil
}
return bp.txQueues[i]
}
func (bp *batchPipelines) tunQueue(i int) chan *overlay.Packet {
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.tunQueues) {
return nil
}
return bp.tunQueues[i]
}
func (bp *batchPipelines) txQueueLen(i int) int {
q := bp.txQueue(i)
if q == nil {
return 0
}
return len(q)
}
func (bp *batchPipelines) tunQueueLen(i int) int {
q := bp.tunQueue(i)
if q == nil {
return 0
}
return len(q)
}
func (bp *batchPipelines) enqueueRx(i int, pkt *overlay.Packet) bool {
q := bp.rxQueue(i)
if q == nil {
return false
}
q <- pkt
return true
}
func (bp *batchPipelines) enqueueTx(i int, pkt *overlay.Packet, addr netip.AddrPort) bool {
q := bp.txQueue(i)
if q == nil {
return false
}
q <- queuedDatagram{packet: pkt, addr: addr}
return true
}
func (bp *batchPipelines) enqueueTun(i int, pkt *overlay.Packet) bool {
q := bp.tunQueue(i)
if q == nil {
return false
}
q <- pkt
return true
}
func (bp *batchPipelines) newPacket() *overlay.Packet {
if bp == nil || !bp.enabled || bp.pool == nil {
return nil
}
return bp.pool.Get()
}

View File

@@ -58,6 +58,9 @@ type Certificate interface {
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
PublicKey() []byte
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
MarshalPublicKeyPEM() []byte
// Curve identifies which curve was used for the PublicKey and Signature.
Curve() Curve

View File

@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey
}
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV1) Signature() []byte {
return c.signature
}

View File

@@ -1,6 +1,7 @@
package cert
import (
"crypto/ed25519"
"fmt"
"net/netip"
"testing"
@@ -13,6 +14,7 @@ import (
)
func TestCertificateV1_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{

View File

@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV2) Signature() []byte {
return c.signature
}

View File

@@ -15,6 +15,7 @@ import (
)
func TestCertificateV2_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
publicKey: pubKey,
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.curve = Curve_P256
nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{
details: detailsV2{

View File

@@ -1,25 +1,34 @@
package cert
import (
"encoding/hex"
"encoding/pem"
"fmt"
"time"
"golang.org/x/crypto/ed25519"
)
const (
const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
)
const ( //key-agreement-key banners
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
)
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
const ( //signing key banners
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
)
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -51,6 +60,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
} else {
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
}
}
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
@@ -62,6 +81,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
}
}
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
@@ -73,7 +105,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32
curve = Curve_CURVE25519
case P256PublicKeyBanner:
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
// Uncompressed
expectedLen = 65
curve = Curve_P256
@@ -108,6 +140,101 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
}
}
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
issuerBytes, err := func() ([]byte, error) {
issuer := c.Issuer()
if issuer == "" {
return nil, nil
}
decoded, err := hex.DecodeString(issuer)
if err != nil {
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
return decoded, nil
}()
if err != nil {
return nil, rest, err
}
pubKey := c.PublicKey()
if pubKey != nil {
pubKey = append([]byte(nil), pubKey...)
}
sig := c.Signature()
if sig != nil {
sig = append([]byte(nil), sig...)
}
return &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: pubKey,
IsCA: c.IsCA(),
Issuer: issuerBytes,
Curve: c.Curve(),
},
Signature: sig,
cert: c,
}, rest, nil
}
// IssuerString returns the issuer in hex format for compatibility
func (n *NebulaCertificate) IssuerString() string {
if n.Details.Issuer == nil {
return ""
}
return hex.EncodeToString(n.Details.Issuer)
}
// Certificate returns the underlying certificate (read-only)
func (n *NebulaCertificate) Certificate() Certificate {
return n.cert
}
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
// consumed data or an error on failure
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {

View File

@@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalX25519PublicKey(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
oldPubP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)

View File

@@ -446,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey
}
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte {
return d.signature
}

View File

@@ -29,6 +29,8 @@ type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -54,25 +56,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
}
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
caB, err := caCrt.MarshalPEM()

View File

@@ -237,7 +237,7 @@ func TestCertDowngrade(t *testing.T) {
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("wtf")
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
}
if since > time.Second*10 {
r.Log("wtf")
@@ -318,50 +318,3 @@ func TestCertMismatchCorrection(t *testing.T) {
myControl.Stop()
theirControl.Stop()
}
func TestCrossStackRelaysWork(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
//myVpnV4 := myVpnIpNet[0]
myVpnV6 := myVpnIpNet[1]
relayVpnV4 := relayVpnIpNet[0]
relayVpnV6 := relayVpnIpNet[1]
theirVpnV6 := theirVpnIpNet[0]
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//t.Log("finish up")
//myControl.Stop()
//theirControl.Stop()
//relayControl.Stop()
}

View File

@@ -423,7 +423,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) {
return nil
@@ -490,12 +490,10 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) bool {
if localCache != nil && localCache.Has(fp) {
return true
}
}
conntrack := f.Conntrack
conntrack.Lock()
@@ -559,7 +557,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
conntrack.Unlock()
if localCache != nil {
localCache[fp] = struct{}{}
localCache.Add(fp)
}
return true

View File

@@ -1,6 +1,7 @@
package firewall
import (
"sync"
"sync/atomic"
"time"
@@ -9,13 +10,58 @@ import (
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[Packet]struct{}
type ConntrackCache struct {
mu sync.Mutex
entries map[Packet]struct{}
}
func newConntrackCache() *ConntrackCache {
return &ConntrackCache{entries: make(map[Packet]struct{})}
}
func (c *ConntrackCache) Has(p Packet) bool {
if c == nil {
return false
}
c.mu.Lock()
_, ok := c.entries[p]
c.mu.Unlock()
return ok
}
func (c *ConntrackCache) Add(p Packet) {
if c == nil {
return
}
c.mu.Lock()
c.entries[p] = struct{}{}
c.mu.Unlock()
}
func (c *ConntrackCache) Len() int {
if c == nil {
return 0
}
c.mu.Lock()
l := len(c.entries)
c.mu.Unlock()
return l
}
func (c *ConntrackCache) Reset(capHint int) {
if c == nil {
return
}
c.mu.Lock()
c.entries = make(map[Packet]struct{}, capHint)
c.mu.Unlock()
}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick atomic.Uint64
cache ConntrackCache
cache *ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
@@ -23,9 +69,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
c := &ConntrackCacheTicker{cache: newConntrackCache()}
go c.tick(d)
@@ -41,17 +85,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
if c == nil {
return nil
}
if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if ll := c.cache.Len(); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
c.cache.Reset(ll)
}
}

View File

@@ -392,7 +392,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
@@ -692,6 +692,50 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
},
}
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
},
}
h1 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c1,
},
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
}
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.0.2.1"),
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
ml := func(m map[string]struct{}, a [][]string) {
for n := 0; n < b.N; n++ {

34
go.mod
View File

@@ -6,32 +6,33 @@ require (
dario.cat/mergo v1.0.2
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/armon/go-radix v1.0.0
github.com/cilium/ebpf v0.12.3
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.20.4
github.com/gaissmai/bart v0.25.0
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.65
github.com/kardianos/service v1.2.4
github.com/miekg/dns v1.1.68
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.22.0
github.com/prometheus/client_golang v1.23.2
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.10.0
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.37.0
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.39.0
golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0
golang.org/x/term v0.31.0
golang.org/x/net v0.45.0
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
golang.org/x/term v0.36.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.6
google.golang.org/protobuf v1.36.8
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)
@@ -43,11 +44,12 @@ require (
github.com/google/btree v1.1.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/mod v0.23.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.30.0 // indirect
golang.org/x/tools v0.33.0 // indirect
)

76
go.sum
View File

@@ -17,6 +17,8 @@ github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6r
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -24,8 +26,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -64,8 +68,8 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
@@ -78,13 +82,14 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -106,24 +111,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -143,29 +148,33 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -176,8 +185,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -185,8 +194,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -197,18 +206,17 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -219,8 +227,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -239,8 +247,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -28,7 +28,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
} else if v < cert.Version2 {
// If we're connecting to a v6 address we must use a v2 cert
// If we're connecting to a v6 address we should encourage use of a V2 cert
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
@@ -52,6 +52,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return false
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -107,6 +108,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available")
return
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -147,8 +149,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp, fperr := rc.Fingerprint()
if fperr != nil {
fp = "<error generating certificate fingerprint>"
}

View File

@@ -300,8 +300,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
InitiatorRelayIndex: idx,
}
relayFrom := hm.f.myVpnAddrs[0]
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
@@ -319,13 +317,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
if vpnIp.Is4() {
relayFrom = hm.f.myVpnAddrs[0]
} else {
//todo do this smarter
relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
}
m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
@@ -340,7 +332,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": relayFrom,
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
"relay": relay}).
@@ -366,8 +358,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
InitiatorRelayIndex: existingRelay.LocalIndex,
}
relayFrom := hm.f.myVpnAddrs[0]
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
@@ -385,14 +375,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
if vpnIp.Is4() {
relayFrom = hm.f.myVpnAddrs[0]
} else {
//todo do this smarter
relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
}
m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
@@ -407,7 +390,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": relayFrom,
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
"relay": relay}).

View File

@@ -2,7 +2,6 @@ package nebula
import (
"errors"
"fmt"
"net"
"net/netip"
"slices"
@@ -522,7 +521,6 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
return nil, nil, errors.New("unable to find host")
}
lastH := h
for h != nil {
for _, targetIp := range targetIps {
r, ok := h.relayState.QueryRelayForByIp(targetIp)
@@ -530,12 +528,10 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
return h, r, nil
}
}
lastH = h
h = h.next
}
//todo no merge
return nil, nil, fmt.Errorf("unable to find host with relay: %v", lastH)
return nil, nil, errors.New("unable to find host with relay")
}
func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
@@ -742,7 +738,8 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
i.networks = new(bart.Lite)
for _, network := range networks {
i.networks.Insert(network)
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
i.networks.Insert(nprefix)
}
for _, network := range unsafeNetworks {

View File

@@ -2,16 +2,18 @@ package nebula
import (
"net/netip"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/routing"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache *firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
@@ -335,9 +337,21 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
if ci.eKey == nil {
return
}
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
target := remote
if !target.IsValid() {
target = hostinfo.remote
}
useRelay := !target.IsValid()
fullOut := out
var pkt *overlay.Packet
if !useRelay && f.batches.Enabled() {
pkt = f.batches.newPacket()
if pkt != nil {
out = pkt.Payload()[:0]
}
}
if useRelay {
if len(out) < header.Len {
// out always has a capacity of mtu, but not always a length greater than the header.Len.
@@ -371,31 +385,62 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
var err error
if len(p) > 0 && slicesOverlap(out, p) {
tmp := make([]byte, len(p))
copy(tmp, p)
p = tmp
}
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
if pkt != nil {
pkt.Release()
}
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("udpAddr", target).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
return
}
if remote.IsValid() {
err = f.writers[q].WriteTo(out, remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
if target.IsValid() {
if pkt != nil {
pkt.Len = len(out)
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"queue": q,
"dest": target,
"payload_len": pkt.Len,
"use_batches": true,
"remote_index": hostinfo.remoteIndexId,
}).Debug("enqueueing packet to UDP batch queue")
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].WriteTo(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
if f.tryQueuePacket(q, pkt, target) {
return
}
} else {
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"queue": q,
"dest": target,
}).Debug("failed to enqueue packet; falling back to immediate send")
}
f.writeImmediatePacket(q, pkt, target, hostinfo)
return
}
if f.tryQueueDatagram(q, out, target) {
return
}
f.writeImmediate(q, out, target, hostinfo)
return
}
// fall back to relay path
if pkt != nil {
pkt.Release()
}
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
@@ -407,5 +452,18 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
break
}
}
}
// slicesOverlap reports whether the two byte slices share any portion of memory.
// cipher.AEAD.Seal requires plaintext and dst to live in disjoint regions.
func slicesOverlap(a, b []byte) bool {
if len(a) == 0 || len(b) == 0 {
return false
}
aStart := uintptr(unsafe.Pointer(&a[0]))
aEnd := aStart + uintptr(len(a))
bStart := uintptr(unsafe.Pointer(&b[0]))
bEnd := bStart + uintptr(len(b))
return aStart < bEnd && bStart < aEnd
}

View File

@@ -8,6 +8,7 @@ import (
"net/netip"
"os"
"runtime"
"strings"
"sync/atomic"
"time"
@@ -21,7 +22,13 @@ import (
"github.com/slackhq/nebula/udp"
)
const mtu = 9001
const (
mtu = 9001
defaultGSOFlushInterval = 150 * time.Microsecond
defaultBatchQueueDepthFactor = 4
defaultGSOMaxSegments = 8
maxKernelGSOSegments = 64
)
type InterfaceConfig struct {
HostMap *HostMap
@@ -36,6 +43,9 @@ type InterfaceConfig struct {
connectionManager *connectionManager
DropLocalBroadcast bool
DropMulticast bool
EnableGSO bool
EnableGRO bool
GSOMaxSegments int
routines int
MessageMetrics *MessageMetrics
version string
@@ -47,6 +57,8 @@ type InterfaceConfig struct {
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
BatchFlushInterval time.Duration
BatchQueueDepth int
l *logrus.Logger
}
@@ -84,9 +96,20 @@ type Interface struct {
version string
conntrackCacheTimeout time.Duration
batchQueueDepth int
enableGSO bool
enableGRO bool
gsoMaxSegments int
batchUDPQueueGauge metrics.Gauge
batchUDPFlushCounter metrics.Counter
batchTunQueueGauge metrics.Gauge
batchTunFlushCounter metrics.Counter
batchFlushInterval atomic.Int64
sendSem chan struct{}
writers []udp.Conn
readers []io.ReadWriteCloser
batches batchPipelines
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
@@ -161,6 +184,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return nil, errors.New("no connection manager")
}
if c.GSOMaxSegments <= 0 {
c.GSOMaxSegments = defaultGSOMaxSegments
}
if c.GSOMaxSegments > maxKernelGSOSegments {
c.GSOMaxSegments = maxKernelGSOSegments
}
if c.BatchQueueDepth <= 0 {
c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor
}
if c.BatchFlushInterval < 0 {
c.BatchFlushInterval = 0
}
if c.BatchFlushInterval == 0 && c.EnableGSO {
c.BatchFlushInterval = defaultGSOFlushInterval
}
cs := c.pki.getCertState()
ifce := &Interface{
pki: c.pki,
@@ -186,6 +225,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
relayManager: c.relayManager,
connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
batchQueueDepth: c.BatchQueueDepth,
enableGSO: c.EnableGSO,
enableGRO: c.EnableGRO,
gsoMaxSegments: c.GSOMaxSegments,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
@@ -198,8 +241,25 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.batchUDPQueueGauge = metrics.GetOrRegisterGauge("batch.udp.queue_depth", nil)
ifce.batchUDPFlushCounter = metrics.GetOrRegisterCounter("batch.udp.flushes", nil)
ifce.batchTunQueueGauge = metrics.GetOrRegisterGauge("batch.tun.queue_depth", nil)
ifce.batchTunFlushCounter = metrics.GetOrRegisterCounter("batch.tun.flushes", nil)
ifce.batchFlushInterval.Store(int64(c.BatchFlushInterval))
ifce.sendSem = make(chan struct{}, c.routines)
ifce.batches.init(c.Inside, c.routines, c.BatchQueueDepth, c.GSOMaxSegments)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
if c.l.Level >= logrus.DebugLevel {
c.l.WithFields(logrus.Fields{
"enableGSO": c.EnableGSO,
"enableGRO": c.EnableGRO,
"gsoMaxSegments": c.GSOMaxSegments,
"batchQueueDepth": c.BatchQueueDepth,
"batchFlush": c.BatchFlushInterval,
"batching": ifce.batches.Enabled(),
}).Debug("initialized batch pipelines")
}
ifce.connectionManager.intf = ifce
@@ -248,6 +308,18 @@ func (f *Interface) run() {
go f.listenOut(i)
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("batching", f.batches.Enabled()).Debug("starting interface run loops")
}
if f.batches.Enabled() {
for i := 0; i < f.routines; i++ {
go f.runInsideBatchWorker(i)
go f.runTunWriteQueue(i)
go f.runSendQueue(i)
}
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
@@ -279,6 +351,17 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
if f.batches.Enabled() {
if br, ok := reader.(overlay.BatchReader); ok {
f.listenInBatchLocked(reader, br, i)
return
}
}
f.listenInLegacyLocked(reader, i)
}
func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
@@ -302,6 +385,581 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
}
}
func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.BatchReader, i int) {
pool := f.batches.Pool()
if pool == nil {
f.l.Warn("batch pipeline enabled without an allocated pool; falling back to single-packet reads")
f.listenInLegacyLocked(raw, i)
return
}
for {
packets, err := reader.ReadIntoBatch(pool)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
}
if isVirtioHeadroomError(err) {
f.l.WithError(err).Warn("Batch reader fell back due to tun headroom issue")
f.listenInLegacyLocked(raw, i)
return
}
f.l.WithError(err).Error("Error while reading outbound packet batch")
os.Exit(2)
}
if len(packets) == 0 {
continue
}
for _, pkt := range packets {
if pkt == nil {
continue
}
if !f.batches.enqueueRx(i, pkt) {
pkt.Release()
}
}
}
}
func (f *Interface) runInsideBatchWorker(i int) {
queue := f.batches.rxQueue(i)
if queue == nil {
return
}
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for pkt := range queue {
if pkt == nil {
continue
}
f.consumeInsidePacket(pkt.Payload(), fwPacket, nb, out, i, conntrackCache.Get(f.l))
pkt.Release()
}
}
func (f *Interface) runSendQueue(i int) {
queue := f.batches.txQueue(i)
if queue == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("queue", i).Debug("tx queue not initialized; batching disabled for writer")
}
return
}
writer := f.writerForIndex(i)
if writer == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("queue", i).Debug("no UDP writer for batch queue")
}
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("queue", i).Debug("send queue worker started")
}
defer func() {
if f.l.Level >= logrus.WarnLevel {
f.l.WithField("queue", i).Warn("send queue worker exited")
}
}()
batchCap := f.batches.batchSizeHint()
if batchCap <= 0 {
batchCap = 1
}
gsoLimit := f.effectiveGSOMaxSegments()
if gsoLimit > batchCap {
batchCap = gsoLimit
}
pending := make([]queuedDatagram, 0, batchCap)
var (
flushTimer *time.Timer
flushC <-chan time.Time
)
dispatch := func(reason string, timerFired bool) {
if len(pending) == 0 {
return
}
batch := pending
f.flushAndReleaseBatch(i, writer, batch, reason)
for idx := range batch {
batch[idx] = queuedDatagram{}
}
pending = pending[:0]
if flushTimer != nil {
if !timerFired {
if !flushTimer.Stop() {
select {
case <-flushTimer.C:
default:
}
}
}
flushTimer = nil
flushC = nil
}
}
armTimer := func() {
delay := f.currentBatchFlushInterval()
if delay <= 0 {
dispatch("nogso", false)
return
}
if flushTimer == nil {
flushTimer = time.NewTimer(delay)
flushC = flushTimer.C
}
}
for {
select {
case d := <-queue:
if d.packet == nil {
continue
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"queue": i,
"payload_len": d.packet.Len,
"dest": d.addr,
}).Debug("send queue received packet")
}
pending = append(pending, d)
if gsoLimit > 0 && len(pending) >= gsoLimit {
dispatch("gso", false)
continue
}
if len(pending) >= cap(pending) {
dispatch("cap", false)
continue
}
armTimer()
f.observeUDPQueueLen(i)
case <-flushC:
dispatch("timer", true)
}
}
}
func (f *Interface) runTunWriteQueue(i int) {
queue := f.batches.tunQueue(i)
if queue == nil {
return
}
writer := f.batches.inside
if writer == nil {
return
}
requiredHeadroom := writer.BatchHeadroom()
batchCap := f.batches.batchSizeHint()
if batchCap <= 0 {
batchCap = 1
}
pending := make([]*overlay.Packet, 0, batchCap)
var (
flushTimer *time.Timer
flushC <-chan time.Time
)
flush := func(reason string, timerFired bool) {
if len(pending) == 0 {
return
}
valid := pending[:0]
for idx := range pending {
if !f.ensurePacketHeadroom(&pending[idx], requiredHeadroom, i, reason) {
pending[idx] = nil
continue
}
if pending[idx] != nil {
valid = append(valid, pending[idx])
}
}
if len(valid) > 0 {
if _, err := writer.WriteBatch(valid); err != nil {
f.l.WithError(err).
WithField("queue", i).
WithField("reason", reason).
Warn("Failed to write tun batch")
for _, pkt := range valid {
if pkt != nil {
f.writePacketToTun(i, pkt)
}
}
}
}
pending = pending[:0]
if flushTimer != nil {
if !timerFired {
if !flushTimer.Stop() {
select {
case <-flushTimer.C:
default:
}
}
}
flushTimer = nil
flushC = nil
}
}
armTimer := func() {
delay := f.currentBatchFlushInterval()
if delay <= 0 {
return
}
if flushTimer == nil {
flushTimer = time.NewTimer(delay)
flushC = flushTimer.C
}
}
for {
select {
case pkt := <-queue:
if pkt == nil {
continue
}
if f.ensurePacketHeadroom(&pkt, requiredHeadroom, i, "queue") {
pending = append(pending, pkt)
}
if len(pending) >= cap(pending) {
flush("cap", false)
continue
}
armTimer()
f.observeTunQueueLen(i)
case <-flushC:
flush("timer", true)
}
}
}
func (f *Interface) flushAndReleaseBatch(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
if len(batch) == 0 {
return
}
f.flushDatagrams(index, writer, batch, reason)
for idx := range batch {
if batch[idx].packet != nil {
batch[idx].packet.Release()
batch[idx].packet = nil
}
}
if f.batchUDPFlushCounter != nil {
f.batchUDPFlushCounter.Inc(int64(len(batch)))
}
}
func (f *Interface) flushDatagrams(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
if len(batch) == 0 {
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"writer": index,
"reason": reason,
"pending": len(batch),
}).Debug("udp batch flush summary")
}
maxSeg := f.effectiveGSOMaxSegments()
if bw, ok := writer.(udp.BatchConn); ok {
chunkCap := maxSeg
if chunkCap <= 0 {
chunkCap = len(batch)
}
chunk := make([]udp.Datagram, 0, chunkCap)
var (
currentAddr netip.AddrPort
segments int
)
flushChunk := func() {
if len(chunk) == 0 {
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"writer": index,
"segments": len(chunk),
"dest": chunk[0].Addr,
"reason": reason,
"pending_total": len(batch),
}).Debug("flushing UDP batch")
}
if err := bw.WriteBatch(chunk); err != nil {
f.l.WithError(err).
WithField("writer", index).
WithField("reason", reason).
Warn("Failed to write UDP batch")
}
chunk = chunk[:0]
segments = 0
}
for _, item := range batch {
if item.packet == nil || !item.addr.IsValid() {
continue
}
payload := item.packet.Payload()[:item.packet.Len]
if segments == 0 {
currentAddr = item.addr
}
if item.addr != currentAddr || (maxSeg > 0 && segments >= maxSeg) {
flushChunk()
currentAddr = item.addr
}
chunk = append(chunk, udp.Datagram{Payload: payload, Addr: item.addr})
segments++
}
flushChunk()
return
}
for _, item := range batch {
if item.packet == nil || !item.addr.IsValid() {
continue
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"writer": index,
"reason": reason,
"dest": item.addr,
"segments": 1,
}).Debug("flushing UDP batch")
}
if err := writer.WriteTo(item.packet.Payload()[:item.packet.Len], item.addr); err != nil {
f.l.WithError(err).
WithField("writer", index).
WithField("udpAddr", item.addr).
WithField("reason", reason).
Warn("Failed to write UDP packet")
}
}
}
func (f *Interface) tryQueueDatagram(q int, buf []byte, addr netip.AddrPort) bool {
if !addr.IsValid() || !f.batches.Enabled() {
return false
}
pkt := f.batches.newPacket()
if pkt == nil {
return false
}
payload := pkt.Payload()
if len(payload) < len(buf) {
pkt.Release()
return false
}
copy(payload, buf)
pkt.Len = len(buf)
if f.batches.enqueueTx(q, pkt, addr) {
f.observeUDPQueueLen(q)
return true
}
pkt.Release()
return false
}
func (f *Interface) writerForIndex(i int) udp.Conn {
if i < 0 || i >= len(f.writers) {
return nil
}
return f.writers[i]
}
func (f *Interface) writeImmediate(q int, buf []byte, addr netip.AddrPort, hostinfo *HostInfo) {
writer := f.writerForIndex(q)
if writer == nil {
f.l.WithField("udpAddr", addr).
WithField("writer", q).
Error("Failed to write outgoing packet: no writer available")
return
}
if err := writer.WriteTo(buf, addr); err != nil {
hostinfo.logger(f.l).
WithError(err).
WithField("udpAddr", addr).
Error("Failed to write outgoing packet")
}
}
func (f *Interface) tryQueuePacket(q int, pkt *overlay.Packet, addr netip.AddrPort) bool {
if pkt == nil || !addr.IsValid() || !f.batches.Enabled() {
return false
}
if f.batches.enqueueTx(q, pkt, addr) {
f.observeUDPQueueLen(q)
return true
}
return false
}
func (f *Interface) writeImmediatePacket(q int, pkt *overlay.Packet, addr netip.AddrPort, hostinfo *HostInfo) {
if pkt == nil {
return
}
writer := f.writerForIndex(q)
if writer == nil {
f.l.WithField("udpAddr", addr).
WithField("writer", q).
Error("Failed to write outgoing packet: no writer available")
pkt.Release()
return
}
if err := writer.WriteTo(pkt.Payload()[:pkt.Len], addr); err != nil {
hostinfo.logger(f.l).
WithError(err).
WithField("udpAddr", addr).
Error("Failed to write outgoing packet")
}
pkt.Release()
}
func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) {
if pkt == nil {
return
}
writer := f.readers[q]
if writer == nil {
pkt.Release()
return
}
if bw, ok := writer.(interface {
WriteBatch([]*overlay.Packet) (int, error)
}); ok {
if _, err := bw.WriteBatch([]*overlay.Packet{pkt}); err != nil {
f.l.WithError(err).WithField("queue", q).Warn("Failed to write tun packet via batch writer")
pkt.Release()
}
return
}
if _, err := writer.Write(pkt.Payload()[:pkt.Len]); err != nil {
f.l.WithError(err).Error("Failed to write to tun")
}
pkt.Release()
}
func (f *Interface) clonePacketWithHeadroom(pkt *overlay.Packet, required int) *overlay.Packet {
if pkt == nil {
return nil
}
payload := pkt.Payload()[:pkt.Len]
if len(payload) == 0 && required <= 0 {
return pkt
}
pool := f.batches.Pool()
if pool != nil {
if clone := pool.Get(); clone != nil {
if len(clone.Payload()) >= len(payload) {
clone.Len = copy(clone.Payload(), payload)
pkt.Release()
return clone
}
clone.Release()
}
}
if required < 0 {
required = 0
}
buf := make([]byte, required+len(payload))
n := copy(buf[required:], payload)
pkt.Release()
return &overlay.Packet{
Buf: buf,
Offset: required,
Len: n,
}
}
func (f *Interface) observeUDPQueueLen(i int) {
if f.batchUDPQueueGauge == nil {
return
}
f.batchUDPQueueGauge.Update(int64(f.batches.txQueueLen(i)))
}
func (f *Interface) observeTunQueueLen(i int) {
if f.batchTunQueueGauge == nil {
return
}
f.batchTunQueueGauge.Update(int64(f.batches.tunQueueLen(i)))
}
func (f *Interface) currentBatchFlushInterval() time.Duration {
if v := f.batchFlushInterval.Load(); v > 0 {
return time.Duration(v)
}
return 0
}
func (f *Interface) ensurePacketHeadroom(pkt **overlay.Packet, required int, queue int, reason string) bool {
p := *pkt
if p == nil {
return false
}
if required <= 0 || p.Offset >= required {
return true
}
clone := f.clonePacketWithHeadroom(p, required)
if clone == nil {
f.l.WithFields(logrus.Fields{
"queue": queue,
"reason": reason,
}).Warn("dropping packet lacking tun headroom")
return false
}
*pkt = clone
return true
}
func isVirtioHeadroomError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "headroom") || strings.Contains(msg, "virtio")
}
func (f *Interface) effectiveGSOMaxSegments() int {
max := f.gsoMaxSegments
if max <= 0 {
max = defaultGSOMaxSegments
}
if max > maxKernelGSOSegments {
max = maxKernelGSOSegments
}
if !f.enableGSO {
return 1
}
return max
}
type udpOffloadConfigurator interface {
ConfigureOffload(enableGSO, enableGRO bool, maxSegments int)
}
func (f *Interface) applyOffloadConfig(enableGSO, enableGRO bool, maxSegments int) {
if maxSegments <= 0 {
maxSegments = defaultGSOMaxSegments
}
if maxSegments > maxKernelGSOSegments {
maxSegments = maxKernelGSOSegments
}
f.enableGSO = enableGSO
f.enableGRO = enableGRO
f.gsoMaxSegments = maxSegments
for _, writer := range f.writers {
if cfg, ok := writer.(udpOffloadConfigurator); ok {
cfg.ConfigureOffload(enableGSO, enableGRO, maxSegments)
}
}
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError)
@@ -404,6 +1062,42 @@ func (f *Interface) reloadMisc(c *config.C) {
f.reQueryWait.Store(int64(n))
f.l.Info("timers.requery_wait_duration has changed")
}
if c.HasChanged("listen.gso_flush_timeout") {
d := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
if d < 0 {
d = 0
}
f.batchFlushInterval.Store(int64(d))
f.l.WithField("duration", d).Info("listen.gso_flush_timeout has changed")
} else if c.HasChanged("batch.flush_interval") {
d := c.GetDuration("batch.flush_interval", defaultGSOFlushInterval)
if d < 0 {
d = 0
}
f.batchFlushInterval.Store(int64(d))
f.l.WithField("duration", d).Warn("batch.flush_interval is deprecated; use listen.gso_flush_timeout")
}
if c.HasChanged("batch.queue_depth") {
n := c.GetInt("batch.queue_depth", f.batchQueueDepth)
if n != f.batchQueueDepth {
f.batchQueueDepth = n
f.l.Warn("batch.queue_depth changes require a restart to take effect")
}
}
if c.HasChanged("listen.enable_gso") || c.HasChanged("listen.enable_gro") || c.HasChanged("listen.gso_max_segments") {
enableGSO := c.GetBool("listen.enable_gso", f.enableGSO)
enableGRO := c.GetBool("listen.enable_gro", f.enableGRO)
maxSeg := c.GetInt("listen.gso_max_segments", f.gsoMaxSegments)
f.applyOffloadConfig(enableGSO, enableGRO, maxSeg)
f.l.WithFields(logrus.Fields{
"enableGSO": enableGSO,
"enableGRO": enableGRO,
"gsoMaxSegments": maxSeg,
}).Info("listen GSO/GRO configuration updated")
}
}
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {

38
main.go
View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"runtime"
"time"
"github.com/sirupsen/logrus"
@@ -143,6 +144,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
// set up our UDP listener
udpConns := make([]udp.Conn, routines)
port := c.GetInt("listen.port", 0)
enableGSO := c.GetBool("listen.enable_gso", true)
enableGRO := c.GetBool("listen.enable_gro", true)
gsoMaxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
if gsoMaxSegments <= 0 {
gsoMaxSegments = defaultGSOMaxSegments
}
if gsoMaxSegments > maxKernelGSOSegments {
gsoMaxSegments = maxKernelGSOSegments
}
gsoFlushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
if gsoFlushTimeout < 0 {
gsoFlushTimeout = 0
}
batchQueueDepth := c.GetInt("batch.queue_depth", 0)
if !configTest {
rawListenHost := c.GetString("listen.host", "0.0.0.0")
@@ -162,13 +177,28 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
listenHost = ips[0].Unmap()
}
useWGDefault := runtime.GOOS == "linux"
useWG := c.GetBool("listen.use_wireguard_stack", useWGDefault)
var mkListener func(*logrus.Logger, netip.Addr, int, bool, int, int) (udp.Conn, error)
if useWG {
mkListener = udp.NewWireguardListener
} else {
mkListener = udp.NewListener
}
for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64), i)
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
//todo set bpf on zeroth socket
udpServer.ReloadConfig(c)
if cfg, ok := udpServer.(interface {
ConfigureOffload(bool, bool, int)
}); ok {
cfg.ConfigureOffload(enableGSO, enableGRO, gsoMaxSegments)
}
udpConns[i] = udpServer
// If port is dynamic, discover it before the next pass through the for loop
@@ -236,12 +266,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: c.GetBool("tun.drop_multicast", false),
EnableGSO: enableGSO,
EnableGRO: enableGRO,
GSOMaxSegments: gsoMaxSegments,
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout,
BatchFlushInterval: gsoFlushTimeout,
BatchQueueDepth: batchQueueDepth,
l: l,
}
@@ -253,6 +288,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
ifce.writers = udpConns
ifce.applyOffloadConfig(enableGSO, enableGRO, gsoMaxSegments)
lightHouse.ifce = ifce
ifce.RegisterConfigChangeCallbacks(c)

View File

@@ -12,6 +12,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"golang.org/x/net/ipv4"
)
@@ -19,7 +20,7 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache *firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -61,7 +62,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, ip, h.RemoteIndex) {
return
}
case header.MessageRelay:
@@ -465,23 +466,45 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
var err error
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache *firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
var (
err error
pkt *overlay.Packet
)
if f.batches.tunQueue(q) != nil {
pkt = f.batches.newPacket()
if pkt != nil {
out = pkt.Payload()[:0]
}
}
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
if pkt != nil {
pkt.Release()
}
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
if addr.IsValid() {
f.maybeSendRecvError(addr, recvIndex)
}
return false
}
err = newPacket(out, true, fwPacket)
if err != nil {
if pkt != nil {
pkt.Release()
}
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
if pkt != nil {
pkt.Release()
}
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return false
@@ -489,6 +512,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
if pkt != nil {
pkt.Release()
}
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
@@ -501,8 +527,17 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
}
f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out)
if err != nil {
if pkt != nil {
pkt.Len = len(out)
if f.batches.enqueueTun(q, pkt) {
f.observeTunQueueLen(q)
return true
}
f.writePacketToTun(q, pkt)
return true
}
if _, err = f.readers[q].Write(out); err != nil {
f.l.WithError(err).Error("Failed to write to tun")
}
return true

View File

@@ -3,6 +3,7 @@ package overlay
import (
"io"
"net/netip"
"sync"
"github.com/slackhq/nebula/routing"
)
@@ -15,3 +16,84 @@ type Device interface {
RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (io.ReadWriteCloser, error)
}
// Packet represents a single packet buffer with optional headroom to carry
// metadata (for example virtio-net headers).
type Packet struct {
Buf []byte
Offset int
Len int
release func()
}
func (p *Packet) Payload() []byte {
return p.Buf[p.Offset : p.Offset+p.Len]
}
func (p *Packet) Reset() {
p.Len = 0
p.Offset = 0
p.release = nil
}
func (p *Packet) Release() {
if p.release != nil {
p.release()
p.release = nil
}
}
func (p *Packet) Capacity() int {
return len(p.Buf) - p.Offset
}
// PacketPool manages reusable buffers with headroom.
type PacketPool struct {
headroom int
blksz int
pool sync.Pool
}
func NewPacketPool(headroom, payload int) *PacketPool {
p := &PacketPool{headroom: headroom, blksz: headroom + payload}
p.pool.New = func() any {
buf := make([]byte, p.blksz)
return &Packet{Buf: buf, Offset: headroom}
}
return p
}
func (p *PacketPool) Get() *Packet {
pkt := p.pool.Get().(*Packet)
pkt.Offset = p.headroom
pkt.Len = 0
pkt.release = func() { p.put(pkt) }
return pkt
}
func (p *PacketPool) put(pkt *Packet) {
pkt.Reset()
p.pool.Put(pkt)
}
// BatchReader allows reading multiple packets into a shared pool with
// preallocated headroom (e.g. virtio-net headers).
type BatchReader interface {
ReadIntoBatch(pool *PacketPool) ([]*Packet, error)
}
// BatchWriter writes a slice of packets that carry their own metadata.
type BatchWriter interface {
WriteBatch(packets []*Packet) (int, error)
}
// BatchCapableDevice describes a device that can efficiently read and write
// batches of packets with virtio headroom.
type BatchCapableDevice interface {
Device
BatchReader
BatchWriter
BatchHeadroom() int
BatchPayloadCap() int
BatchSize() int
}

View File

@@ -1,6 +1,8 @@
package overlay
import (
"fmt"
"net"
"net/netip"
"github.com/sirupsen/logrus"
@@ -70,3 +72,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
for _, gateway := range gateways {
if dest.Addr().Is4() && gateway.Addr().Is4() {
return gateway, nil
}
if dest.Addr().Is6() && gateway.Addr().Is6() {
return gateway, nil
}
}
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
}

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
"sync/atomic"
@@ -295,7 +294,6 @@ func (t *tun) activate6(network netip.Prefix) error {
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD,
}
@@ -554,13 +552,3 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -10,11 +10,9 @@ import (
"io"
"io/fs"
"net/netip"
"os"
"os/exec"
"strconv"
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/gaissmai/bart"
@@ -22,12 +20,18 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678
TUNSIFMODE = 0x8004745e
TUNSIFHEAD = 0x80047460
OSIOCAIFADDR_IN6 = 0x8088691b
IN6_IFF_NODAD = 0x0020
)
type fiodgnameArg struct {
@@ -37,43 +41,159 @@ type fiodgnameArg struct {
}
type ifreqRename struct {
Name [16]byte
Name [unix.IFNAMSIZ]byte
Data uintptr
}
type ifreqDestroy struct {
Name [16]byte
Name [unix.IFNAMSIZ]byte
pad [16]byte
}
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
}
type ifreqMTU struct {
Name [unix.IFNAMSIZ]byte
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
}
io.ReadWriteCloser
func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if t.devFd >= 0 {
err := syscall.Close(t.devFd)
if err != nil {
return err
t.l.WithError(err).Error("Error closing device")
}
t.devFd = -1
c := make(chan struct{})
go func() {
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
}
}()
// wait up to 1 second so we start blocking at the ioctl
select {
case <-c:
case <-time.After(1 * time.Second):
}
}
return nil
@@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var fd int
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
}
if err != nil {
return nil, err
}
rawConn, err := file.SyscallConn()
if err != nil {
return nil, fmt.Errorf("SyscallConn: %v", err)
// Read the name of the interface
var name [16]byte
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
}
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil {
return nil, err
}
@@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
// If the name doesn't match the desired interface name, rename it now
if ifName != deviceName {
s, err := syscall.Socket(
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return nil, err
}
@@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
devFd: fd,
}
err = t.reload(c, true)
@@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
if cidr.Addr().Is4() {
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
if cidr.Addr().Is6() {
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
return t.addRoutes(false)
}
func (t *tun) setMTU() error {
// Set the MTU on the device
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -268,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -289,9 +484,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -306,3 +500,120 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}

View File

@@ -9,6 +9,7 @@ import (
"net"
"net/netip"
"os"
"runtime"
"strings"
"sync/atomic"
"time"
@@ -19,6 +20,7 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
wgtun "github.com/slackhq/nebula/wgstack/tun"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
@@ -33,6 +35,7 @@ type tun struct {
TXQueueLen int
deviceIndex int
ioctlFd uintptr
wgDevice wgtun.Device
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -68,7 +71,9 @@ type ifreqQLEN struct {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
useWGDefault := runtime.GOOS == "linux"
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
if err != nil {
return nil, err
}
@@ -113,7 +118,9 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
useWGDefault := runtime.GOOS == "linux"
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
if err != nil {
return nil, err
}
@@ -123,16 +130,45 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, useWireguard bool) (*tun, error) {
var (
rw io.ReadWriteCloser = file
fd = int(file.Fd())
wgDev wgtun.Device
)
if useWireguard {
dev, err := wgtun.CreateTUNFromFile(file, c.GetInt("tun.mtu", DefaultMTU))
if err != nil {
return nil, fmt.Errorf("failed to initialize wireguard tun device: %w", err)
}
wgDev = dev
rw = newWireguardTunIO(dev, c.GetInt("tun.mtu", DefaultMTU))
fd = int(dev.File().Fd())
}
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
ReadWriteCloser: rw,
fd: fd,
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
l: l,
}
if wgDev != nil {
t.wgDevice = wgDev
}
if wgDev != nil {
// replace ioctl fd with device file descriptor to keep route management working
file = wgDev.File()
t.fd = int(file.Fd())
t.ioctlFd = file.Fd()
}
if t.ioctlFd == 0 {
t.ioctlFd = file.Fd()
}
err := t.reload(c, true)
if err != nil {
@@ -678,6 +714,14 @@ func (t *tun) Close() error {
_ = t.ReadWriteCloser.Close()
}
if t.wgDevice != nil {
_ = t.wgDevice.Close()
if t.ioctlFd > 0 {
// underlying fd already closed by the device
t.ioctlFd = 0
}
}
if t.ioctlFd > 0 {
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
}

View File

@@ -0,0 +1,56 @@
//go:build linux && !android && !e2e_testing
package overlay
import "fmt"
func (t *tun) batchIO() (*wireguardTunIO, bool) {
io, ok := t.ReadWriteCloser.(*wireguardTunIO)
return io, ok
}
func (t *tun) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
io, ok := t.batchIO()
if !ok {
return nil, fmt.Errorf("wireguard batch I/O not enabled")
}
return io.ReadIntoBatch(pool)
}
func (t *tun) WriteBatch(packets []*Packet) (int, error) {
io, ok := t.batchIO()
if ok {
return io.WriteBatch(packets)
}
for _, pkt := range packets {
if pkt == nil {
continue
}
if _, err := t.Write(pkt.Payload()[:pkt.Len]); err != nil {
return 0, err
}
pkt.Release()
}
return len(packets), nil
}
func (t *tun) BatchHeadroom() int {
if io, ok := t.batchIO(); ok {
return io.BatchHeadroom()
}
return 0
}
func (t *tun) BatchPayloadCap() int {
if io, ok := t.batchIO(); ok {
return io.BatchPayloadCap()
}
return 0
}
func (t *tun) BatchSize() int {
if io, ok := t.batchIO(); ok {
return io.BatchSize()
}
return 1
}

View File

@@ -4,13 +4,12 @@
package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
@@ -20,11 +19,42 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
type ifreqDestroy struct {
Name [16]byte
pad [16]byte
const (
SIOCAIFADDR_IN6 = 0x8080696b
TUNSIFHEAD = 0x80047442
TUNSIFMODE = 0x80047458
)
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
}
type ifreq struct {
Name [unix.IFNAMSIZ]byte
data int
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type tun struct {
@@ -34,40 +64,18 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
io.ReadWriteCloser
f *os.File
fd int
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
@@ -77,13 +85,19 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{
ReadWriteCloser: file,
f: os.NewFile(uintptr(fd), ""),
fd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -105,40 +119,225 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil
}
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifr := ifreq{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
return err
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
rc, err := t.f.SyscallConn()
if err != nil {
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
}
var errno syscall.Errno
var n uintptr
err = rc.Read(func(fd uintptr) bool {
// first 4 bytes is protocol family, in network byte order
head := [4]byte{}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
if errno.Temporary() {
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
return false
}
return true
})
if err != nil {
if err == syscall.EBADF || err.Error() == "use of closed file" {
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
return 0, os.ErrClosed
}
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
}
if errno != 0 {
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, nil
} else if bytesRead < 4 {
return 0, nil
} else {
return bytesRead - 4, nil
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head [4]byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head[3] = syscall.AF_INET
} else if ipVer == 6 {
head[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var errno syscall.Errno
var n uintptr
err = rc.Write(func(fd uintptr) bool {
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
// According to NetBSD documentation for TUN, writes will only return errors in which
// this packet will never be delivered so just go on living life.
return true
})
if err != nil {
return 0, err
}
if errno != 0 {
return 0, errno
}
return int(n) - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
return nil
}
// Unsafe path routes
return t.addRoutes(false)
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime = addrLifetime{
Vltime: 0xffffffff,
Pltime: 0xffffffff,
}
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
mode := int32(unix.IFF_BROADCAST)
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
if err != nil {
return fmt.Errorf("failed to set tun device mode: %w", err)
}
v := 1
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
if err != nil {
return fmt.Errorf("failed to set tun device head: %w", err)
}
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -197,21 +396,23 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -224,10 +425,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -242,3 +441,109 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}

View File

@@ -4,23 +4,50 @@
package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
SIOCAIFADDR_IN6 = 0x8080691a
)
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime [2]uint32
}
type ifreq struct {
Name [unix.IFNAMSIZ]byte
data int
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
@@ -28,44 +55,42 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
io.ReadWriteCloser
f *os.File
fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{
ReadWriteCloser: file,
f: os.NewFile(uintptr(fd), ""),
fd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -87,6 +112,154 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil
}
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.f.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.f.Write(buf)
return n - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
}
err = addRoute(cidr, t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
}
return nil
}
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime[0] = 0xffffffff
req.Lifetime[1] = 0xffffffff
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
@@ -124,86 +297,11 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
@@ -213,43 +311,159 @@ func (t *tun) Name() string {
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
n, err := t.ReadWriteCloser.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
copy(buf[4:], from)
n, err := t.ReadWriteCloser.Write(buf)
return n - 4, err
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}

View File

@@ -0,0 +1,220 @@
//go:build linux && !android && !e2e_testing
package overlay
import (
"fmt"
"sync"
wgtun "github.com/slackhq/nebula/wgstack/tun"
)
type wireguardTunIO struct {
dev wgtun.Device
mtu int
batchSize int
readMu sync.Mutex
readBuffers [][]byte
readLens []int
legacyBuf []byte
writeMu sync.Mutex
writeBuf []byte
writeWrap [][]byte
writeBuffers [][]byte
}
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
batch := dev.BatchSize()
if batch <= 0 {
batch = 1
}
if mtu <= 0 {
mtu = DefaultMTU
}
return &wireguardTunIO{
dev: dev,
mtu: mtu,
batchSize: batch,
readLens: make([]int, batch),
legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
writeWrap: make([][]byte, 1),
}
}
func (w *wireguardTunIO) Read(p []byte) (int, error) {
w.readMu.Lock()
defer w.readMu.Unlock()
bufs := w.readBuffers
if len(bufs) == 0 {
bufs = [][]byte{w.legacyBuf}
w.readBuffers = bufs
}
n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, nil
}
length := w.readLens[0]
copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
return length, nil
}
func (w *wireguardTunIO) Write(p []byte) (int, error) {
if len(p) > w.mtu {
return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
}
w.writeMu.Lock()
defer w.writeMu.Unlock()
buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
buf[i] = 0
}
copy(buf[wgtun.VirtioNetHdrLen:], p)
w.writeWrap[0] = buf
n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
if err != nil {
return n, err
}
return len(p), nil
}
func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
if pool == nil {
return nil, fmt.Errorf("wireguard tun: packet pool is nil")
}
w.readMu.Lock()
defer w.readMu.Unlock()
if len(w.readBuffers) < w.batchSize {
w.readBuffers = make([][]byte, w.batchSize)
}
if len(w.readLens) < w.batchSize {
w.readLens = make([]int, w.batchSize)
}
packets := make([]*Packet, w.batchSize)
requiredHeadroom := w.BatchHeadroom()
requiredPayload := w.BatchPayloadCap()
headroom := 0
for i := 0; i < w.batchSize; i++ {
pkt := pool.Get()
if pkt == nil {
releasePackets(packets[:i])
return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
}
if pkt.Capacity() < requiredPayload {
pkt.Release()
releasePackets(packets[:i])
return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
}
if i == 0 {
headroom = pkt.Offset
if headroom < requiredHeadroom {
pkt.Release()
releasePackets(packets[:i])
return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
}
} else if pkt.Offset != headroom {
pkt.Release()
releasePackets(packets[:i])
return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
}
packets[i] = pkt
w.readBuffers[i] = pkt.Buf
}
n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
if err != nil {
releasePackets(packets)
return nil, err
}
if n == 0 {
releasePackets(packets)
return nil, nil
}
for i := 0; i < n; i++ {
packets[i].Len = w.readLens[i]
}
for i := n; i < w.batchSize; i++ {
packets[i].Release()
packets[i] = nil
}
return packets[:n], nil
}
func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
if len(packets) == 0 {
return 0, nil
}
requiredHeadroom := w.BatchHeadroom()
offset := packets[0].Offset
if offset < requiredHeadroom {
releasePackets(packets)
return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
}
for _, pkt := range packets {
if pkt == nil {
continue
}
if pkt.Offset != offset {
releasePackets(packets)
return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
}
limit := pkt.Offset + pkt.Len
if limit > len(pkt.Buf) {
releasePackets(packets)
return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
}
}
w.writeMu.Lock()
defer w.writeMu.Unlock()
if len(w.writeBuffers) < len(packets) {
w.writeBuffers = make([][]byte, len(packets))
}
for i, pkt := range packets {
if pkt == nil {
w.writeBuffers[i] = nil
continue
}
limit := pkt.Offset + pkt.Len
w.writeBuffers[i] = pkt.Buf[:limit]
}
n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
if err != nil {
return n, err
}
releasePackets(packets)
return n, nil
}
func (w *wireguardTunIO) BatchHeadroom() int {
return wgtun.VirtioNetHdrLen
}
func (w *wireguardTunIO) BatchPayloadCap() int {
return w.mtu
}
func (w *wireguardTunIO) BatchSize() int {
return w.batchSize
}
func (w *wireguardTunIO) Close() error {
return nil
}
func releasePackets(pkts []*Packet) {
for _, pkt := range pkts {
if pkt != nil {
pkt.Release()
}
}
}

View File

@@ -190,7 +190,6 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
InitiatorRelayIndex: peerRelay.RemoteIndex,
}
relayFrom := h.vpnAddrs[0]
if v == cert.Version1 {
peer := peerHostInfo.vpnAddrs[0]
if !peer.Is4() {
@@ -208,13 +207,7 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
b = targetAddr.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
if targetAddr.Is4() {
relayFrom = h.vpnAddrs[0]
} else {
//todo do this smarter
relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
}
resp.RelayFromAddr = netAddrToProtoAddr(relayFrom)
resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
resp.RelayToAddr = target
}
@@ -367,7 +360,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
}
relayFrom := h.vpnAddrs[0]
if v == cert.Version1 {
if !h.vpnAddrs[0].Is4() {
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
@@ -384,13 +377,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
b = target.As4()
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
if target.Is4() {
relayFrom = h.vpnAddrs[0]
} else {
//todo do this smarter
relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
}
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
req.RelayToAddr = netAddrToProtoAddr(target)
}
@@ -401,7 +388,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
} else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": relayFrom,
"relayFrom": h.vpnAddrs[0],
"relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,

View File

@@ -22,6 +22,18 @@ type Conn interface {
Close() error
}
// Datagram represents a UDP payload destined to a specific address.
type Datagram struct {
Payload []byte
Addr netip.AddrPort
}
// BatchConn can send multiple datagrams in one syscall.
type BatchConn interface {
Conn
WriteBatch(pkts []Datagram) error
}
type NoopConn struct{}
func (NoopConn) Rebind() error {

View File

@@ -32,7 +32,7 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
return ip, false
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) {
af := unix.AF_INET6
if ip.Is4() {
af = unix.AF_INET
@@ -310,31 +310,51 @@ func (u *StdConn) Close() error {
}
func NewUDPStatsEmitter(udpConns []Conn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges
var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
var meminfo [unix.SK_MEMINFO_VARS]uint32
if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
for i := range udpConns {
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
if len(udpConns) == 0 {
return func() {}
}
type statsProvider struct {
index int
conn *StdConn
}
providers := make([]statsProvider, 0, len(udpConns))
for i, c := range udpConns {
if sc, ok := c.(*StdConn); ok {
providers = append(providers, statsProvider{index: i, conn: sc})
}
}
if len(providers) == 0 {
return func() {}
}
var meminfo [unix.SK_MEMINFO_VARS]uint32
if err := providers[0].conn.getMemInfo(&meminfo); err != nil {
return func() {}
}
udpGauges := make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(providers))
for i, provider := range providers {
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", provider.index), nil),
}
}
return func() {
for i, gauges := range udpGauges {
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
for i, provider := range providers {
if err := provider.conn.getMemInfo(&meminfo); err == nil {
for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
gauges[j].Update(int64(meminfo[j]))
udpGauges[i][j].Update(int64(meminfo[j]))
}
}
}

226
udp/wireguard_conn_linux.go Normal file
View File

@@ -0,0 +1,226 @@
//go:build linux && !android && !e2e_testing
package udp
import (
"errors"
"net"
"net/netip"
"sync"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
wgconn "github.com/slackhq/nebula/wgstack/conn"
)
// WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
type WGConn struct {
l *logrus.Logger
bind *wgconn.StdNetBind
recvers []wgconn.ReceiveFunc
batch int
reqBatch int
localIP netip.Addr
localPort uint16
enableGSO bool
enableGRO bool
gsoMaxSeg int
closed atomic.Bool
q int
closeOnce sync.Once
}
// NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) {
bind := wgconn.NewStdNetBindForAddr(ip, multi, q)
recvers, actualPort, err := bind.Open(uint16(port))
if err != nil {
return nil, err
}
if batch <= 0 {
batch = bind.BatchSize()
} else if batch > bind.BatchSize() {
batch = bind.BatchSize()
}
return &WGConn{
l: l,
bind: bind,
recvers: recvers,
batch: batch,
reqBatch: batch,
localIP: ip,
localPort: actualPort,
q: q,
}, nil
}
func (c *WGConn) Rebind() error {
// WireGuard's bind does not support rebinding in place.
return nil
}
func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
// Fallback to wildcard IPv4 for display purposes.
return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
}
return netip.AddrPortFrom(c.localIP, c.localPort), nil
}
func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
batchSize := c.batch
packets := make([][]byte, batchSize)
for i := range packets {
packets[i] = make([]byte, 0xffff)
}
sizes := make([]int, batchSize)
endpoints := make([]wgconn.Endpoint, batchSize)
for {
if c.closed.Load() {
return
}
n, err := fn(packets, sizes, endpoints)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
if c.l != nil {
c.l.WithError(err).Debug("wireguard UDP listener receive error")
}
continue
}
for i := 0; i < n; i++ {
if sizes[i] == 0 {
continue
}
stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
if !ok {
if c.l != nil {
c.l.Warn("wireguard UDP listener received unexpected endpoint type")
}
continue
}
addr := stdEp.AddrPort
r(addr, packets[i][:sizes[i]])
endpoints[i] = nil
}
}
}
func (c *WGConn) ListenOut(r EncReader) {
for _, fn := range c.recvers {
go c.listen(fn, r)
}
}
func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
if len(b) == 0 {
return nil
}
if c.closed.Load() {
return net.ErrClosed
}
ep := &wgconn.StdNetEndpoint{AddrPort: addr}
return c.bind.Send([][]byte{b}, ep)
}
func (c *WGConn) WriteBatch(datagrams []Datagram) error {
if len(datagrams) == 0 {
return nil
}
if c.closed.Load() {
return net.ErrClosed
}
max := c.batch
if max <= 0 {
max = len(datagrams)
if max == 0 {
max = 1
}
}
bufs := make([][]byte, 0, max)
var (
current netip.AddrPort
endpoint *wgconn.StdNetEndpoint
haveAddr bool
)
flush := func() error {
if len(bufs) == 0 || endpoint == nil {
bufs = bufs[:0]
return nil
}
err := c.bind.Send(bufs, endpoint)
bufs = bufs[:0]
return err
}
for _, d := range datagrams {
if len(d.Payload) == 0 || !d.Addr.IsValid() {
continue
}
if !haveAddr || d.Addr != current {
if err := flush(); err != nil {
return err
}
current = d.Addr
endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
haveAddr = true
}
bufs = append(bufs, d.Payload)
if len(bufs) >= max {
if err := flush(); err != nil {
return err
}
}
}
return flush()
}
func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
c.enableGSO = enableGSO
c.enableGRO = enableGRO
if maxSegments <= 0 {
maxSegments = 1
} else if maxSegments > wgconn.IdealBatchSize {
maxSegments = wgconn.IdealBatchSize
}
c.gsoMaxSeg = maxSegments
effectiveBatch := c.reqBatch
if enableGSO && c.bind != nil {
bindBatch := c.bind.BatchSize()
if effectiveBatch < bindBatch {
if c.l != nil {
c.l.WithFields(logrus.Fields{
"requested": c.reqBatch,
"effective": bindBatch,
}).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
}
effectiveBatch = bindBatch
}
}
c.batch = effectiveBatch
if c.l != nil {
c.l.WithFields(logrus.Fields{
"enableGSO": enableGSO,
"enableGRO": enableGRO,
"gsoMaxSegments": maxSegments,
}).Debug("configured wireguard UDP offload")
}
}
func (c *WGConn) ReloadConfig(*config.C) {
// WireGuard bind currently does not expose runtime configuration knobs.
}
func (c *WGConn) Close() error {
var err error
c.closeOnce.Do(func() {
c.closed.Store(true)
err = c.bind.Close()
})
return err
}

View File

@@ -0,0 +1,15 @@
//go:build !linux || android || e2e_testing
package udp
import (
"fmt"
"net/netip"
"github.com/sirupsen/logrus"
)
// NewWireguardListener is only available on Linux builds.
func NewWireguardListener(*logrus.Logger, netip.Addr, int, bool, int) (Conn, error) {
return nil, fmt.Errorf("wireguard experimental UDP listener is only supported on Linux")
}

587
wgstack/conn/bind_std.go Normal file
View File

@@ -0,0 +1,587 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var (
_ Bind = (*StdNetBind)(nil)
)
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these two fields are not guarded by mu
udpAddrPool sync.Pool
msgsPool sync.Pool
blackhole4 bool
blackhole6 bool
q int
}
// NewStdNetBind creates a bind that listens on all interfaces.
func NewStdNetBind() *StdNetBind {
return newStdNetBind().(*StdNetBind)
}
// NewStdNetBindForAddr creates a bind that listens on a specific address.
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
// IPv6 socket will be created.
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool, q int) *StdNetBind {
b := NewStdNetBind()
b.q = q
//if addr.IsValid() {
// if addr.IsUnspecified() {
// // keep dual-stack defaults with empty listen addresses
// } else if addr.Is4() {
// b.listenAddr4 = addr.Unmap().String()
// b.bindV4 = true
// b.bindV6 = false
// } else {
// b.listenAddr6 = addr.Unmap().String()
// b.bindV6 = true
// b.bindV4 = false
// }
//}
//b.reusePort = reusePort
return b
}
func newStdNetBind() Bind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
msgsPool: sync.Pool{
New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
}
return &msgs
},
},
}
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if
// supported. Typically this is a PKTINFO structure from/for control
// messages, see unix.PKTINFO for an example.
src []byte
}
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &StdNetEndpoint{
AddrPort: e,
}, nil
}
func (e *StdNetEndpoint) ClearSrc() {
if e.src != nil {
// Truncate src, no need to reallocate.
e.src = e.src[:0]
}
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
func listenNet(network string, port int, q int) (*net.UDPConn, int, error) {
lc := listenConfig(q)
conn, err := lc.ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
if q == 0 {
if EvilFdZero == 0 {
panic("fuck")
}
err = reusePortHax(EvilFdZero)
if err != nil {
return nil, 0, fmt.Errorf("reuse port hax: %v", err)
}
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn.(*net.UDPConn), uaddr.Port, nil
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
var err error
var tries int
if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
v4conn, port, err = listenNet("udp4", port, s.q)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port, s.q)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
if v4conn != nil {
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
}
if v6conn != nil {
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
}
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
s.msgsPool.Put(msgs)
}
func (s *StdNetBind) getMessages() *[]ipv6.Message {
return s.msgsPool.Get().(*[]ipv6.Message)
}
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize
}
return 1
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
s.ipv4PC = nil
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
s.ipv6PC = nil
}
s.blackhole4 = false
s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
type ErrUDPGSODisabled struct {
onLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
offload := s.ipv4TxOffload
br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
br = s.ipv6PC
is6 = true
offload = s.ipv6TxOffload
}
s.mu.Unlock()
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 {
as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else {
as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
}
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
return err
}
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
start int
)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for {
n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
return err
}
const (
// Exceeding these values results in EMSGSIZE. They account for layer3 and
// layer4 headers. IPv6 does not need to account for itself as the payload
// length field is self excluding.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
// This is a hard limit imposed by the kernel.
udpSegmentMaxDatagrams = 64
)
type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
maxPayloadLen := maxIPv4PayloadLen
if ep.DstIP().Is6() {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buf := range bufs {
if i > 0 {
msgLen := len(buf)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}

131
wgstack/conn/conn.go Normal file
View File

@@ -0,0 +1,131 @@
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"errors"
"fmt"
"net/netip"
"reflect"
"runtime"
"strings"
)
const (
IdealBatchSize = 128 // maximum number of packets handled per read and write
)
// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
// depending on the platform-specific implementation.
type Bind interface {
// Open puts the Bind into a listening state on a given port and reports the actual
// port that it bound to. Passing zero results in a random selection.
// fns is the set of functions that will be called to receive packets.
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
// Close closes the Bind listener.
// All fns returned by Open must return net.ErrClosed after a call to Close.
Close() error
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
// Send writes one or more packets in bufs to address ep. The length of
// bufs must not exceed BatchSize().
Send(bufs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
// BatchSize is the number of buffers expected to be passed to
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being
// tied to a single network interface. Used by wireguard-windows.
type BindSocketToInterface interface {
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
}
// PeekLookAtSocketFd is implemented by Bind objects that support having their
// file descriptor peeked at. Used by wireguard-android.
type PeekLookAtSocketFd interface {
PeekLookAtSocketFd4() (fd int, err error)
PeekLookAtSocketFd6() (fd int, err error)
}
// An Endpoint maintains the source/destination caching for a peer.
//
// dst: the remote address of a peer ("endpoint" in uapi terminology)
// src: the local address from which datagrams originate going to the peer
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() netip.Addr
SrcIP() netip.Addr
}
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)
func (fn ReceiveFunc) PrettyName() string {
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// 0. cheese/taco.beansIPv6.func12.func21218-fm
name = strings.TrimSuffix(name, "-fm")
// 1. cheese/taco.beansIPv6.func12.func21218
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
name = name[idx+1:]
// 2. taco.beansIPv6.func12.func21218
}
for {
var idx int
for idx = len(name) - 1; idx >= 0; idx-- {
if name[idx] < '0' || name[idx] > '9' {
break
}
}
if idx == len(name)-1 {
break
}
const dotFunc = ".func"
if !strings.HasSuffix(name[:idx+1], dotFunc) {
break
}
name = name[:idx+1-len(dotFunc)]
// 3. taco.beansIPv6.func12
// 4. taco.beansIPv6
}
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
name = name[idx+1:]
// 5. beansIPv6
}
if name == "" {
return fmt.Sprintf("%p", fn)
}
if strings.HasSuffix(name, "IPv4") {
return "v4"
}
if strings.HasSuffix(name, "IPv6") {
return "v6"
}
return name
}

222
wgstack/conn/controlfns.go Normal file
View File

@@ -0,0 +1,222 @@
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"fmt"
"net"
"syscall"
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/asm"
)
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
// the max supported by a default configuration of macOS. Some platforms will
// silently clamp the value to other maximums, such as linux clamping to
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
// around this limitation)
const socketBufferSize = 7 << 20
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
const SO_ATTACH_REUSEPORT_EBPF = 52
//Create eBPF program that returns a hash to distribute packets
func createReuseportProgram() (*ebpf.Program, error) {
// This program uses the packet's hash and returns it modulo number of sockets
// Simple version: just return a counter-based distribution
//instructions := asm.Instructions{
// // Load the skb->hash value (already computed by kernel)
// asm.LoadMem(asm.R0, asm.R1, int16(unsafe.Offsetof(unix.XDPMd{}.RxQueueIndex)), asm.Word),
// asm.Return(),
//}
//
//// Alternative: simpler round-robin approach
//// This returns the CPU number, effectively round-robin
//instructions := asm.Instructions{
// asm.Mov.Reg(asm.R0, asm.R1), // Move ctx to R0
// asm.LoadMem(asm.R0, asm.R1, 0, asm.Word), // Load some field
// asm.Return(),
//}
// Better: Use BPF helper to get random/hash value
//instructions := asm.Instructions{
// // Call get_prandom_u32() to get random value for distribution
// asm.Mov.Imm(asm.R0, 0),
// asm.Call.Label("get_prandom_u32"),
// asm.Return(),
//}
//
//prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
// Type: ebpf.SocketFilter,
// Instructions: instructions,
// License: "GPL",
//})
//instructions := asm.Instructions{
// // R1 contains pointer to skb
// // Load skb->hash at offset 0x20 (may vary by kernel, but 0x20 is common)
// asm.LoadMem(asm.R0, asm.R1, 0x20, asm.Word),
//
// // If hash is 0, use rxhash instead (fallback)
// asm.JEq.Imm(asm.R0, 0, "use_rxhash"),
// asm.Return().Sym("return"),
//
// // Fallback: load rxhash
// asm.LoadMem(asm.R0, asm.R1, 0x24, asm.Word).Sym("use_rxhash"),
// asm.Return(),
//}
//
//prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
// Type: ebpf.SkReuseport,
// Instructions: instructions,
// License: "GPL",
//})
//instructions := asm.Instructions{
// // R1 = ctx (sk_reuseport_md)
// // R2 = sk_reuseport map (we'll use NULL/0 for default behavior)
// // R3 = key (select socket index)
// // R4 = flags
//
// // Simple approach: use the hash field from sk_reuseport_md
// // struct sk_reuseport_md { ... __u32 hash; ... } at offset 24
// asm.Mov.Reg(asm.R6, asm.R1), // Save ctx
//
// // Load the hash value at offset 24
// asm.LoadMem(asm.R2, asm.R6, 24, asm.Word),
//
// // Call bpf_sk_select_reuseport(ctx, map, key, flags)
// asm.Mov.Reg(asm.R1, asm.R6), // ctx
// asm.Mov.Imm(asm.R2, 0), // map (NULL = use default)
// asm.Mov.Reg(asm.R3, asm.R2), // key = hash we loaded (in R2)
// asm.Mov.Imm(asm.R4, 0), // flags
// asm.Call.Label("sk_select_reuseport"),
//
// // Return 0
// asm.Mov.Imm(asm.R0, 0),
// asm.Return(),
//}
//
//prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
// Type: ebpf.SkReuseport,
// Instructions: instructions,
// License: "GPL",
//})
instructions := asm.Instructions{
// R1 = ctx (sk_reuseport_md pointer)
// Load hash from sk_reuseport_md at offset 24
//asm.LoadMem(asm.R0, asm.R1, 20, asm.Word),
// R1 = ctx (save it)
asm.Mov.Reg(asm.R6, asm.R1),
// Prepare string on stack: "BPF called!\n"
// We need to build the format string on the stack
asm.Mov.Reg(asm.R1, asm.R10), // R1 = frame pointer
asm.Add.Imm(asm.R1, -16), // R1 = stack location for string
// Write "BPF called!\n" to stack (we'll use a simpler version)
// Store immediate 64-bit values
asm.StoreImm(asm.R1, 0, 0x2066706220, asm.DWord), // "bpf "
asm.StoreImm(asm.R1, 8, 0x0a21, asm.DWord), // "!\n"
// Call bpf_trace_printk(fmt, fmt_size)
// R1 already points to format string
asm.Mov.Imm(asm.R2, 16), // R2 = format size
asm.Call.Label("bpf_printk"),
// Return 0 (send to socket 0 for testing)
asm.Mov.Imm(asm.R0, 0),
asm.Return(),
//asm.Mov.Imm(asm.R0, 0),
//// Just return the hash directly
//// The kernel will automatically modulo by number of sockets
//asm.Return(),
}
prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
Type: ebpf.SkReuseport,
Instructions: instructions,
License: "GPL",
})
return prog, err
}
//func createReuseportProgram() (*ebpf.Program, error) {
// // Try offset 20 (common in newer kernels)
// instructions := asm.Instructions{
// asm.LoadMem(asm.R0, asm.R1, 20, asm.Word),
// asm.Return(),
// }
//
// prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
// Type: ebpf.SkReuseport,
// Instructions: instructions,
// License: "GPL",
// })
//
// return prog, err
//}
func reusePortHax(fd uintptr) error {
prog, err := createReuseportProgram()
if err != nil {
return fmt.Errorf("failed to create eBPF program: %w", err)
}
//defer prog.Close()
sockErr := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, prog.FD())
if sockErr != nil {
return sockErr
}
return nil
}
var EvilFdZero uintptr
// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig(q int) *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}
if q == 0 {
c.Control(func(fd uintptr) {
EvilFdZero = fd
})
// var e error
// err := c.Control(func(fd uintptr) {
// e = reusePortHax(fd)
// })
// if err != nil {
// return err
// }
// if e != nil {
// return e
// }
}
return nil
},
}
}

View File

@@ -0,0 +1,66 @@
//go:build linux
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
// fail silently - the result of failure is lower performance on very fast
// links or high latency links.
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
// Set up to *mem_max
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
// Set beyond *mem_max if CAP_NET_ADMIN
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) //todo!!!
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) //todo!!!
_ = unix.SetsockoptInt(int(fd), unix.SOL_UDP, unix.UDP_SEGMENT, 0xffff) //todo!!!
//print(err.Error())
})
},
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
func(network, address string, c syscall.RawConn) error {
var err error
switch network {
case "udp4":
if runtime.GOOS != "android" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
})
}
case "udp6":
c.Control(func(fd uintptr) {
if runtime.GOOS != "android" {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
if err != nil {
return
}
}
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
default:
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
}
return err
},
)
}

9
wgstack/conn/default.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !windows
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
func NewDefaultBind() Bind { return NewStdNetBind() }

View File

@@ -0,0 +1,12 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func errShouldDisableUDPGSO(err error) bool {
return false
}

View File

@@ -0,0 +1,26 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
func errShouldDisableUDPGSO(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not have
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
// See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
return serr.Err == unix.EIO
}
return false
}

View File

@@ -0,0 +1,15 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
return
}

View File

@@ -0,0 +1,33 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
rc, err := conn.SyscallConn()
if err != nil {
return
}
a := 0
err = rc.Control(func(fd uintptr) {
a, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
txOffload = err == nil
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
rxOffload = errSyscall == nil && opt == 1
})
fmt.Printf("%d", a)
if err != nil {
return false, false
}
return txOffload, rxOffload
}

View File

@@ -0,0 +1,21 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data.
const gsoControlSize = 0

65
wgstack/conn/gso_linux.go Normal file
View File

@@ -0,0 +1,65 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"unsafe"
"golang.org/x/sys/unix"
)
const (
sizeOfGSOData = 2
)
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
var gso uint16
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
return int(gso), nil
}
}
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
existingLen := len(*control)
avail := cap(*control) - existingLen
space := unix.CmsgSpace(sizeOfGSOData)
if avail < space {
return
}
*control = (*control)[:cap(*control)]
gsoControl := (*control)[existingLen:]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
*control = (*control)[:existingLen+space]
}
// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)

64
wgstack/conn/mark_unix.go Normal file
View File

@@ -0,0 +1,64 @@
//go:build linux || openbsd || freebsd
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"runtime"
"golang.org/x/sys/unix"
)
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
if s.ipv4 != nil {
fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
if s.ipv6 != nil {
fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,42 @@
//go:build !linux || android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net/netip"
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return ""
}
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but use alternatively named flags and need
// ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
const stickyControlSize = 0
const StdNetSupportsStickySockets = false

View File

@@ -0,0 +1,105 @@
package conn
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func (e *StdNetEndpoint) SrcIP() netip.Addr {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return netip.AddrFrom4(info.Spec_dst)
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
// TODO: set zone. in order to do so we need to check if the address is
// link local, and if it is perform a syscall to turn the ifindex into a
// zone string because netip uses string zones.
return netip.AddrFrom16(info.Addr)
}
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return info.Ifindex
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return int32(info.Ifindex)
}
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return e.SrcIP().String()
}
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
ep.ClearSrc()
var (
hdr unix.Cmsghdr
data []byte
rem []byte = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return
}
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
}
}
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
if cap(*control) < len(ep.src) {
return
}
*control = (*control)[:0]
*control = append(*control, ep.src...)
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true

42
wgstack/tun/checksum.go Normal file
View File

@@ -0,0 +1,42 @@
package tun
import "encoding/binary"
// TODO: Explore SIMD and/or other assembly optimizations.
func checksumNoFold(b []byte, initial uint64) uint64 {
ac := initial
i := 0
n := len(b)
for n >= 4 {
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
n -= 4
i += 4
}
for n >= 2 {
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
n -= 2
i += 2
}
if n == 1 {
ac += uint64(b[i]) << 8
}
return ac
}
func checksum(b []byte, initial uint64) uint16 {
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumNoFold(srcAddr, 0)
sum = checksumNoFold(dstAddr, sum)
sum = checksumNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumNoFold(tmp, sum)
}

3
wgstack/tun/export.go Normal file
View File

@@ -0,0 +1,3 @@
package tun
const VirtioNetHdrLen = virtioNetHdrLen

View File

@@ -0,0 +1,630 @@
//go:build linux
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package tun
import (
"bytes"
"encoding/binary"
"errors"
"io"
"unsafe"
wgconn "github.com/slackhq/nebula/wgstack/conn"
"golang.org/x/sys/unix"
)
var ErrTooManySegments = errors.New("tun: too many segments for TSO")
const tcpFlagsOffset = 13
const (
tcpFlagFIN uint8 = 0x01
tcpFlagPSH uint8 = 0x08
tcpFlagACK uint8 = 0x10
)
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
flags uint8
gsoType uint8
hdrLen uint16
gsoSize uint16
csumStart uint16
csumOffset uint16
}
func (v *virtioNetHdr) decode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
return nil
}
func (v *virtioNetHdr) encode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
return nil
}
const (
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
// flowKey represents the key for a flow.
type flowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
}
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
type tcpGROTable struct {
itemsByFlow map[flowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[flowKey][]tcpGROItem, wgconn.IdealBatchSize),
itemsPool: make([][]tcpGROItem, wgconn.IdealBatchSize),
}
for i := range t.itemsPool {
t.itemsPool[i] = make([]tcpGROItem, 0, wgconn.IdealBatchSize)
}
return t
}
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
key := flowKey{}
addrSize := dstAddr - srcAddr
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
iphLen: uint8(tcphOffset),
tcphLen: uint8(tcphLen),
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
}
items, ok := t.itemsByFlow[key]
if !ok {
items = t.newItems()
}
items = append(items, item)
t.itemsByFlow[key] = items
}
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items, _ := t.itemsByFlow[item.key]
items[i] = item
}
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
}
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key flowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
tcphLen uint8 // tcp header len
pshSet bool // psh flag is set
}
func (t *tcpGROTable) newItems() []tcpGROItem {
var items []tcpGROItem
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
return items
}
func (t *tcpGROTable) reset() {
for k, items := range t.itemsByFlow {
items = items[:0]
t.itemsPool = append(t.itemsPool, items)
delete(t.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
const (
coalescePrepend canCoalesce = -1
coalesceUnavailable canCoalesce = 0
coalesceAppend canCoalesce = 1
)
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if tcphLen != item.tcphLen {
// cannot coalesce with unequal tcp options len
return coalesceUnavailable
}
if tcphLen > 20 {
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
// cannot coalesce with unequal tcp options
return coalesceUnavailable
}
}
if pkt[0]>>4 == 6 {
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
// cannot coalesce with unequal Traffic class values
return coalesceUnavailable
}
if pkt[7] != pktTarget[7] {
// cannot coalesce with unequal Hop limit values
return coalesceUnavailable
}
} else {
if pkt[1] != pktTarget[1] {
// cannot coalesce with unequal ToS values
return coalesceUnavailable
}
if pkt[6]>>5 != pktTarget[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return coalesceUnavailable
}
if pkt[8] != pktTarget[8] {
// cannot coalesce with unequal TTL values
return coalesceUnavailable
}
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
if item.pshSet {
// We cannot append to a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
if pshSet {
// We cannot prepend with a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if gsoSize < item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
if gsoSize > item.gsoSize && item.numMerged > 0 {
// There's at least one previous merge, and we're larger than all
// previous. This would put multiple smaller packets on the end.
return coalesceUnavailable
}
return coalescePrepend
}
return coalesceUnavailable
}
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
tcpTotalLen := uint16(len(pkt) - int(iphLen))
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = 0
coalescePSHEnding coalesceResult = 1
coalesceItemInvalidCSum coalesceResult = 2
coalescePktInvalidCSum coalesceResult = 3
coalesceSuccess coalesceResult = 4
)
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
var pktHead []byte // the packet that will end up at the front
headersLen := item.iphLen + item.tcphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
// Copy data
if mode == coalescePrepend {
pktHead = pkt
if cap(pkt)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if pshSet {
return coalescePSHEnding
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
extendBy := coalescedLen - len(pktHead)
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
// Flip the slice headers in bufs as part of prepend. The index of item
// is already being tracked for writing.
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
} else {
pktHead = bufs[item.bufsIndex][bufsOffset:]
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
// We are appending a segment with PSH set.
item.pshSet = pshSet
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
}
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(headersLen),
gsoSize: uint16(item.gsoSize),
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
// (IPv4) header checksum.
if isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
}
hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
// Calculate the pseudo header checksum and place it at the TCP checksum
// offset. Downstream checksum offloading will combine this with computation
// of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := bufsOffset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
item.numMerged++
return coalesceSuccess
}
const (
ipv4FlagMoreFragments uint8 = 0x20
)
const (
ipv4SrcAddrOffset = 12
ipv6SrcAddrOffset = 8
maxUint16 = 1<<16 - 1
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It will return false when pktI is not
// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
// should be written to the Device.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return false
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return false
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return false
}
}
if len(pkt) < iphLen {
return false
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return false
}
if len(pkt) < iphLen+tcphLen {
return false
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return false
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
var pshSet bool
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return false
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return false
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return false
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
// more efficient if there are multiple items for a given flow. This
// also enables a natural table.deleteAt() in the
// coalesceItemInvalidCSum case without the need for index tracking.
// This algorithm makes a best effort to coalesce in the event of
// unordered packets, where pkt may land anywhere in items from a
// sequence number perspective, however once an item is inserted into
// the table it is never compared across other items later.
item := items[i]
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
if can != coalesceUnavailable {
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return true
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return false
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return false
}
func isTCP4NoIPOptions(b []byte) bool {
if len(b) < 40 {
return false
}
if b[0]>>4 != 4 {
return false
}
if b[0]&0x0F != 5 {
return false
}
if b[9] != unix.IPPROTO_TCP {
return false
}
return true
}
func isTCP6NoEH(b []byte) bool {
if len(b) < 60 {
return false
}
if b[0]>>4 != 6 {
return false
}
if b[6] != unix.IPPROTO_TCP {
return false
}
return true
}
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets.
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var coalesced bool
switch {
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
}
if !coalesced {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
*toWrite = append(*toWrite, i)
}
}
return nil
}
// tcpTSO splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
iphLen := int(hdr.csumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
in[10], in[11] = 0, 0 // clear ipv4 header checksum
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
nextSegmentDataAt := int(hdr.hdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
if i == len(outBuffs) {
return i - 1, ErrTooManySegments
}
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
if nextSegmentEnd > len(in) {
nextSegmentEnd = len(in)
}
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
totalLen := int(hdr.hdrLen) + segmentDataLen
sizes[i] = totalLen
out := outBuffs[i][outOffset:]
copy(out, in[:iphLen])
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
if i > 0 {
id := binary.BigEndian.Uint16(out[4:])
id += uint16(i)
binary.BigEndian.PutUint16(out[4:], id)
}
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
ipv4CSum := ^checksum(out[:iphLen], 0)
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
} else {
// For IPv6 we are responsible for updating the payload length field.
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// TCP header
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
// FIN and PSH should only be set on last segment
clearFlags := tcpFlagFIN | tcpFlagPSH
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
}
// payload
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// TCP checksum
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
nextSegmentDataAt += int(hdr.gsoSize)
}
return i, nil
}
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
cSumAt := cSumStart + cSumOffset
// The initial value at the checksum offset should be summed with the
// checksum we compute. This is typically the pseudo-header checksum.
initial := binary.BigEndian.Uint16(in[cSumAt:])
in[cSumAt], in[cSumAt+1] = 0, 0
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
return nil
}

52
wgstack/tun/tun.go Normal file
View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package tun
import (
"os"
)
type Event int
const (
EventUp = 1 << iota
EventDown
EventMTUUpdate
)
type Device interface {
// File returns the file descriptor of the device.
File() *os.File
// Read one or more packets from the Device (without any additional headers).
// On a successful read it returns the number of packets read, and sets
// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
// A nonzero offset can be used to instruct the Device on where to begin
// reading into each element of the bufs slice.
Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
// Write one or more packets to the device (without any additional headers).
// On a successful write it returns the number of packets written. A nonzero
// offset can be used to instruct the Device on where to begin writing from
// each packet contained within the bufs slice.
Write(bufs [][]byte, offset int) (int, error)
// MTU returns the MTU of the Device.
MTU() (int, error)
// Name returns the current name of the Device.
Name() (string, error)
// Events returns a channel of type Event, which is fed Device events.
Events() <-chan Event
// Close stops the Device and closes the Event channel.
Close() error
// BatchSize returns the preferred/max number of packets that can be read or
// written in a single read/write call. BatchSize must not change over the
// lifetime of a Device.
BatchSize() int
}

664
wgstack/tun/tun_linux.go Normal file
View File

@@ -0,0 +1,664 @@
//go:build linux
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package tun
/* Implementation of the TUN device interface for linux
*/
import (
"errors"
"fmt"
"os"
"sync"
"syscall"
"time"
"unsafe"
wgconn "github.com/slackhq/nebula/wgstack/conn"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
const (
cloneDevicePath = "/dev/net/tun"
ifReqSize = unix.IFNAMSIZ + 64
)
type NativeTun struct {
tunFile *os.File
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
batchSize int
vnetHdr bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
readOpMu sync.Mutex // readOpMu guards readBuff
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
toWrite []int
tcp4GROTable, tcp6GROTable *tcpGROTable
}
func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
func (tun *NativeTun) routineHackListener() {
defer tun.hackListenerClosed.Unlock()
/* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch.
*/
last := 0
const (
up = 1
down = 2
)
for {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return
}
err2 := sysconn.Control(func(fd uintptr) {
_, err = unix.Write(int(fd), nil)
})
if err2 != nil {
return
}
switch err {
case unix.EINVAL:
if last != up {
// If the tunnel is up, it reports that write() is
// allowed but we provided invalid data.
tun.events <- EventUp
last = up
}
case unix.EIO:
if last != down {
// If the tunnel is down, it reports that no I/O
// is possible, without checking our provided data.
tun.events <- EventDown
last = down
}
default:
return
}
select {
case <-time.After(time.Second):
// nothing
case <-tun.statusListenersShutdown:
return
}
}
}
func createNetlinkSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {
return -1, err
}
return sock, nil
}
func (tun *NativeTun) routineNetlinkListener() {
defer func() {
unix.Close(tun.netlinkSock)
tun.hackListenerClosed.Lock()
close(tun.events)
tun.netlinkCancel.Close()
}()
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !tun.netlinkCancel.ReadyRead() {
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
return
}
}
if err != nil {
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
return
}
select {
case <-tun.statusListenersShutdown:
return
default:
}
wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if int(hdr.Len) > len(remain) {
break
}
switch hdr.Type {
case unix.NLMSG_DONE:
remain = []byte{}
case unix.RTM_NEWLINK:
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
remain = remain[hdr.Len:]
if info.Index != tun.index {
// not our interface
continue
}
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
// Don't emit EventDown before we've ever emitted EventUp.
// This avoids a startup race with HackListener, which
// might detect Up before we have finished reporting Down.
if wasEverUp {
tun.events <- EventDown
}
}
tun.events <- EventMTUUpdate
default:
remain = remain[hdr.Len:]
}
}
}
}
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFINDEX),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return 0, errno
}
return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
}
func (tun *NativeTun) setMTU(n int) error {
name, err := tun.Name()
if err != nil {
return err
}
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return errno
}
return nil
}
func (tun *NativeTun) routineNetlinkRead() {
defer func() {
unix.Close(tun.netlinkSock)
tun.hackListenerClosed.Lock()
close(tun.events)
tun.netlinkCancel.Close()
}()
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !tun.netlinkCancel.ReadyRead() {
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
return
}
}
if err != nil {
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
return
}
wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if int(hdr.Len) > len(remain) {
break
}
switch hdr.Type {
case unix.NLMSG_DONE:
remain = []byte{}
case unix.RTM_NEWLINK:
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
remain = remain[hdr.Len:]
if info.Index != tun.index {
continue
}
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
if wasEverUp {
tun.events <- EventDown
}
}
tun.events <- EventMTUUpdate
default:
remain = remain[hdr.Len:]
}
}
}
}
func (tun *NativeTun) routineNetlink() {
var err error
tun.netlinkSock, err = createNetlinkSocket()
if err != nil {
tun.errors <- fmt.Errorf("failed to create netlink socket: %w", err)
return
}
tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
if err != nil {
tun.errors <- fmt.Errorf("failed to create netlink cancel: %w", err)
return
}
go tun.routineNetlinkListener()
}
func (tun *NativeTun) Close() error {
var err1, err2 error
tun.closeOnce.Do(func() {
if tun.statusListenersShutdown != nil {
close(tun.statusListenersShutdown)
if tun.netlinkCancel != nil {
err1 = tun.netlinkCancel.Cancel()
}
} else if tun.events != nil {
close(tun.events)
}
err2 = tun.tunFile.Close()
})
if err1 != nil {
return err1
}
return err2
}
func (tun *NativeTun) BatchSize() int {
return tun.batchSize
}
const (
// TODO: support TSO with ECN bits
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
)
func (tun *NativeTun) initFromFlags(name string) error {
sc, err := tun.tunFile.SyscallConn()
if err != nil {
return err
}
if e := sc.Control(func(fd uintptr) {
var (
ifr *unix.Ifreq
)
ifr, err = unix.NewIfreq(name)
if err != nil {
return
}
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
if err != nil {
return
}
got := ifr.Uint16()
if got&unix.IFF_VNET_HDR != 0 {
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
if err != nil {
return
}
tun.vnetHdr = true
tun.batchSize = wgconn.IdealBatchSize
} else {
tun.batchSize = 1
}
}); e != nil {
return e
}
return err
}
// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
}
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
tun, err := CreateTUNFromFile(fd, mtu)
if err != nil {
return nil, err
}
if name != "tun" {
if err := tun.(*NativeTun).initFromFlags(name); err != nil {
tun.Close()
return nil, fmt.Errorf("CreateTUN(%q) failed to set flags: %w", name, err)
}
}
return tun, nil
}
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
errors: make(chan error, 5),
events: make(chan Event, 5),
}
name, err := tun.Name()
if err != nil {
return nil, fmt.Errorf("failed to determine TUN name: %w", err)
}
if err := tun.initFromFlags(name); err != nil {
return nil, fmt.Errorf("failed to query TUN flags: %w", err)
}
if tun.batchSize == 0 {
tun.batchSize = 1
}
tun.index, err = getIFIndex(name)
if err != nil {
return nil, fmt.Errorf("failed to get TUN index: %w", err)
}
if err = tun.setMTU(mtu); err != nil {
return nil, fmt.Errorf("failed to set MTU: %w", err)
}
tun.statusListenersShutdown = make(chan struct{})
go tun.routineNetlink()
if tun.batchSize == 0 {
tun.batchSize = 1
}
tun.tcp4GROTable = newTCPGROTable()
tun.tcp6GROTable = newTCPGROTable()
return tun, nil
}
func (tun *NativeTun) Name() (string, error) {
tun.nameOnce.Do(tun.initNameCache)
return tun.nameCache, tun.nameErr
}
func (tun *NativeTun) initNameCache() {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
tun.nameErr = err
return
}
err = sysconn.Control(func(fd uintptr) {
var ifr [ifReqSize]byte
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
fd,
uintptr(unix.TUNGETIFF),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
tun.nameErr = errno
return
}
tun.nameCache = unix.ByteSliceToString(ifr[:])
})
if err != nil && tun.nameErr == nil {
tun.nameErr = err
}
}
func (tun *NativeTun) MTU() (int, error) {
name, err := tun.Name()
if err != nil {
return 0, err
}
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return 0, errno
}
return int(*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
}
func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.writeOpMu.Lock()
defer func() {
tun.tcp4GROTable.reset()
tun.tcp6GROTable.reset()
tun.writeOpMu.Unlock()
}()
var (
errs error
total int
)
tun.toWrite = tun.toWrite[:0]
if tun.vnetHdr {
err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
if err != nil {
return 0, err
}
offset -= virtioNetHdrLen
} else {
for i := range bufs {
tun.toWrite = append(tun.toWrite, i)
}
}
for _, bufsI := range tun.toWrite {
n, err := tun.tunFile.Write(bufs[bufsI][offset:])
if errors.Is(err, syscall.EBADFD) {
return total, os.ErrClosed
}
if err != nil {
errs = errors.Join(errs, err)
} else {
total += n
}
}
return total, errs
}
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
var hdr virtioNetHdr
if err := hdr.decode(in); err != nil {
return 0, err
}
in = in[virtioNetHdrLen:]
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
if err := gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset); err != nil {
return 0, err
}
}
if len(in) > len(bufs[0][offset:]) {
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
}
n := copy(bufs[0][offset:], in)
sizes[0] = n
return 1, nil
}
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
case 6:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
if len(in) <= int(hdr.csumStart+12) {
return 0, errors.New("packet is too short")
}
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
hdr.hdrLen = hdr.csumStart + tcpHLen
if len(in) < int(hdr.hdrLen) {
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
if hdr.hdrLen < hdr.csumStart {
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
cSumAt := int(hdr.csumStart + hdr.csumOffset)
if cSumAt+1 >= len(in) {
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
return tcpTSO(in, hdr, bufs, sizes, offset)
}
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
tun.readOpMu.Lock()
defer tun.readOpMu.Unlock()
select {
case err := <-tun.errors:
return 0, err
default:
readInto := bufs[0][offset:]
if tun.vnetHdr {
readInto = tun.readBuff[:]
}
n, err := tun.tunFile.Read(readInto)
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
if err != nil {
return 0, err
}
if tun.vnetHdr {
return handleVirtioRead(readInto[:n], bufs, sizes, offset)
}
sizes[0] = n
return 1, nil
}
}